From 8fa014c91f97b8056336a8cbeb366c48b5332e31 Mon Sep 17 00:00:00 2001
From: "Hsieh, Chiao" <chsieh16@illinois.edu>
Date: Mon, 29 Nov 2021 17:17:13 -0600
Subject: [PATCH] Convert z3 vals to float numpy array

---
 firstpass_learner.py | 28 +++++++++++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/firstpass_learner.py b/firstpass_learner.py
index 2debfc3..5693951 100644
--- a/firstpass_learner.py
+++ b/firstpass_learner.py
@@ -1,4 +1,4 @@
-from typing import Sequence
+from typing import Mapping, Sequence, Optional, Tuple
 
 import numpy as np
 import z3
@@ -39,6 +39,7 @@ class FirstpassLearner(Z3LearnerBase):
         super().__init__(state_dim=2, perc_dim=2)
 
         self._in_shape_pred: z3.ExprRef = z3.BoolVal(True)
+        self._part_affine_map: Mapping[int, Tuple] = {}
         self.set_grammar(None)
 
     @property
@@ -57,6 +58,8 @@ class FirstpassLearner(Z3LearnerBase):
         intercept_arr = np.array(z3.RealVector("b%d" % idx, self.perc_dim))
         radius = z3.Real("r%d" % idx)
 
+        self._part_affine_map[idx] = (coeff_arr, intercept_arr, radius)
+
         transformed_perc_seq = affine_transform(self._state_vars, self._percept_vars, coeff_arr, intercept_arr)
         return z3.And(
             l2_norm(*transformed_perc_seq)  <= radius,
@@ -83,6 +86,29 @@ class FirstpassLearner(Z3LearnerBase):
     def add_implication_examples(self, *args) -> None:
         return super().add_implication_examples(*args)
 
+    def learn(self) -> Optional[Tuple]:
+        res = self._solver.check()
+        if res == z3.sat:
+            z3_coeff_arr, z3_intercept_arr, z3_radius = self._part_affine_map[0]  # FIXME
+            m = self._solver.model()
+            coeff_arr = np.vectorize(extract_from_z3_model(m))(z3_coeff_arr)
+            intercept_arr = np.vectorize(extract_from_z3_model(m))(z3_intercept_arr)
+            radius = extract_from_z3_model(m)(z3_radius)
+            return coeff_arr, intercept_arr, radius
+        elif res == z3.unsat:
+            print("(unsat)")
+            return None
+        else:
+            print("(unknown)")
+            return None
+
+
+def extract_from_z3_model(m: z3.ModelRef):
+    def to_float(var):
+        x2 = m[var].as_fraction()
+        return float(x2.numerator) / float(x2.denominator)
+    return to_float
+
 
 def test_affine_transform():
     state_vars = z3.RealVector("x", 3)
-- 
GitLab