From b8382e5c82c785ee8f29acd16d16ab91faaaaadd Mon Sep 17 00:00:00 2001
From: "Hsieh, Chiao" <chsieh16@illinois.edu>
Date: Mon, 29 Nov 2021 16:33:30 -0600
Subject: [PATCH] Use only circle for shape

---
 firstpass_learner.py | 44 +++++++++++++++++++++++++++++---------------
 1 file changed, 29 insertions(+), 15 deletions(-)

diff --git a/firstpass_learner.py b/firstpass_learner.py
index b0c5679..2debfc3 100644
--- a/firstpass_learner.py
+++ b/firstpass_learner.py
@@ -6,12 +6,13 @@ import z3
 from learner_base import Z3LearnerBase
 
 
-def affine_transform(exprs: Sequence[z3.ExprRef],
+def affine_transform(state_vars: Sequence[z3.ExprRef], perc_vars: Sequence[z3.ExprRef],
         coeff: np.ndarray, intercept: np.ndarray) -> Sequence[z3.ExprRef]:
     assert len(coeff.shape) == 2  # Matrix is a 2D array
-    assert coeff.shape == (intercept.shape[0], len(exprs))
+    assert coeff.shape == (intercept.shape[0], len(state_vars))
+    assert intercept.shape[0] == len(perc_vars)
 
-    return [z3.Sum(*(coeff[i][j]*exprs[j] for j in range(coeff.shape[1])), intercept[i])
+    return [perc_vars[i] - z3.Sum(*(coeff[i][j]*state_vars[j] for j in range(coeff.shape[1])), intercept[i])
             for i in range(coeff.shape[0])]
 
 def abs_expr(expr: z3.ExprRef) -> z3.ExprRef:
@@ -56,25 +57,27 @@ class FirstpassLearner(Z3LearnerBase):
         intercept_arr = np.array(z3.RealVector("b%d" % idx, self.perc_dim))
         radius = z3.Real("r%d" % idx)
 
-        transformed_perc_seq = affine_transform(self._state_vars, coeff_arr, intercept_arr)
+        transformed_perc_seq = affine_transform(self._state_vars, self._percept_vars, coeff_arr, intercept_arr)
         return z3.And(
-            z3.AtLeast(*shape_sel, 1),
-            z3.Implies(shape_sel[0], l1_norm(*transformed_perc_seq)  <= radius),
-            z3.Implies(shape_sel[1], l2_norm(*transformed_perc_seq)  <= radius),
-            z3.Implies(shape_sel[2], loo_norm(*transformed_perc_seq) <= radius),
+            l2_norm(*transformed_perc_seq)  <= radius,
+            # z3.AtLeast(*shape_sel, 1),
+            # z3.Implies(shape_sel[0], l1_norm(*transformed_perc_seq)  <= radius),
+            # z3.Implies(shape_sel[1], l2_norm(*transformed_perc_seq)  <= radius),
+            # z3.Implies(shape_sel[2], loo_norm(*transformed_perc_seq) <= radius),
         )
         
     def add_positive_examples(self, *vals) -> None:
         assert all(len(val) == self.state_dim+self.perc_dim for val in vals)
-        
         self._solver.add(*(
-            z3.substitute_vars(self._in_shape_pred, *val) for val in vals
+            z3.substitute_vars(self._in_shape_pred,
+                               *(z3.RealVal(v) for v in val)) for val in vals
         ))
 
     def add_negative_examples(self, *vals) -> None:
         assert all(len(val) == self.state_dim+self.perc_dim for val in vals)
         self._solver.add(*(
-            z3.Not(z3.substitute_vars(self._in_shape_pred, *val)) for val in vals
+            z3.Not(z3.substitute_vars(self._in_shape_pred,
+                                      *(z3.RealVal(v) for v in val))) for val in vals
         ))
 
     def add_implication_examples(self, *args) -> None:
@@ -104,11 +107,22 @@ def test_norms():
 
 def test_firstpass_learner():
     learner = FirstpassLearner()
-    vals = [z3.RealVal(0.0)]*learner.state_dim + [z3.RealVal(0.5)]*learner.perc_dim
-    learner.add_positive_examples(vals)
+    pos_examples = [
+        (10.0, -10.0, 11.0, -11.0),
+        (10.0, -10.0, 9.0, -9.0),
+        (1.0, 1.0, 1.5, 2.3),
+    ]
+    learner.add_positive_examples(*pos_examples)
+    neg_examples = [
+        (1.0, 1.0, 0.5, 0.5),
+        (1.0, 1.0, 1.5, 1.5),
+        (9.0, 9.0, 5.0, 5.0),
+    ]
+    learner.add_negative_examples(*neg_examples)
+    print(learner._solver.assertions())
     print(learner.learn())
 
 if __name__ == "__main__":
-    test_affine_transform()
-    test_norms()
+    # test_affine_transform()
+    # test_norms()
     test_firstpass_learner()
-- 
GitLab