diff --git a/firstpass_learner.py b/firstpass_learner.py index 2debfc3806a89063d239a0943c5cc1c09e39dcb2..56939513692f0966609edfe140103308d5671209 100644 --- a/firstpass_learner.py +++ b/firstpass_learner.py @@ -1,4 +1,4 @@ -from typing import Sequence +from typing import Mapping, Sequence, Optional, Tuple import numpy as np import z3 @@ -39,6 +39,7 @@ class FirstpassLearner(Z3LearnerBase): super().__init__(state_dim=2, perc_dim=2) self._in_shape_pred: z3.ExprRef = z3.BoolVal(True) + self._part_affine_map: Mapping[int, Tuple] = {} self.set_grammar(None) @property @@ -57,6 +58,8 @@ class FirstpassLearner(Z3LearnerBase): intercept_arr = np.array(z3.RealVector("b%d" % idx, self.perc_dim)) radius = z3.Real("r%d" % idx) + self._part_affine_map[idx] = (coeff_arr, intercept_arr, radius) + transformed_perc_seq = affine_transform(self._state_vars, self._percept_vars, coeff_arr, intercept_arr) return z3.And( l2_norm(*transformed_perc_seq) <= radius, @@ -83,6 +86,29 @@ class FirstpassLearner(Z3LearnerBase): def add_implication_examples(self, *args) -> None: return super().add_implication_examples(*args) + def learn(self) -> Optional[Tuple]: + res = self._solver.check() + if res == z3.sat: + z3_coeff_arr, z3_intercept_arr, z3_radius = self._part_affine_map[0] # FIXME + m = self._solver.model() + coeff_arr = np.vectorize(extract_from_z3_model(m))(z3_coeff_arr) + intercept_arr = np.vectorize(extract_from_z3_model(m))(z3_intercept_arr) + radius = extract_from_z3_model(m)(z3_radius) + return coeff_arr, intercept_arr, radius + elif res == z3.unsat: + print("(unsat)") + return None + else: + print("(unknown)") + return None + + +def extract_from_z3_model(m: z3.ModelRef): + def to_float(var): + x2 = m[var].as_fraction() + return float(x2.numerator) / float(x2.denominator) + return to_float + def test_affine_transform(): state_vars = z3.RealVector("x", 3)