From b4bcb5e62a4108c92355f1261b4f4387321a7c25 Mon Sep 17 00:00:00 2001
From: "Hsieh, Chiao" <chsieh16@illinois.edu>
Date: Mon, 29 Nov 2021 15:38:53 -0600
Subject: [PATCH] Add SMT encoding for learner of firstpass example

---
 firstpass_learner.py | 114 +++++++++++++++++++++++++++++++++++++++++++
 learner_base.py      |  26 ++++++++++
 2 files changed, 140 insertions(+)
 create mode 100644 firstpass_learner.py

diff --git a/firstpass_learner.py b/firstpass_learner.py
new file mode 100644
index 0000000..b0c5679
--- /dev/null
+++ b/firstpass_learner.py
@@ -0,0 +1,114 @@
+from typing import Sequence
+
+import numpy as np
+import z3
+
+from learner_base import Z3LearnerBase
+
+
+def affine_transform(exprs: Sequence[z3.ExprRef],
+        coeff: np.ndarray, intercept: np.ndarray) -> Sequence[z3.ExprRef]:
+    """Return, for each row of ``coeff``, the z3 sum ``coeff[row] . exprs + intercept[row]``."""
+    assert len(coeff.shape) == 2  # Matrix is a 2D array
+    assert coeff.shape == (intercept.shape[0], len(exprs))  # one intercept per row, one expr per column
+    return [z3.Sum(*(coeff[row][col]*exprs[col] for col in range(coeff.shape[1])), intercept[row])
+            for row in range(coeff.shape[0])]
+
+def abs_expr(expr: z3.ExprRef) -> z3.ExprRef:
+    return z3.If(expr >= 0, expr, -expr)  # absolute value encoded as an if-then-else term
+
+def l1_norm(*exprs) -> z3.ExprRef:
+    return z3.Sum(*(abs_expr(expr) for expr in exprs))  # sum of absolute values
+
+def l2_norm(*exprs) -> z3.ExprRef:
+    return z3.Sum(*(expr**2 for expr in exprs))  # NOTE: sum of squares (squared norm); no square root is taken
+
+def max_expr(*exprs) -> z3.ExprRef:
+    largest = exprs[0]  # seed the fold with the first term; needs at least one argument
+    for candidate in exprs[1:]:
+        largest = z3.If(largest >= candidate, largest, candidate)  # keep the greater of the two
+    return largest
+
+def loo_norm(*exprs) -> z3.ExprRef:
+    return max_expr(*(abs_expr(expr) for expr in exprs))  # largest absolute value among exprs
+
+
+class FirstpassLearner(Z3LearnerBase):
+    """Z3 learner bounding a norm (l1, l2 or loo) of an affine transform of the state vars."""
+
+    def __init__(self) -> None:
+        super().__init__(state_dim=2, perc_dim=2)
+        self._in_shape_pred: z3.ExprRef = z3.BoolVal(True)
+        self.set_grammar(None)
+
+    @property
+    def num_shapes(self) -> int:
+        return len([l1_norm, l2_norm, loo_norm])
+
+    def set_grammar(self, grammar) -> None:
+        # TODO replace hardcoded grammar
+        # Use z3.And: Python's `a and b` returns one operand instead of building a conjunction.
+        self._in_shape_pred = z3.And(self._in_shape_pred, self._get_shape_def(0))
+
+    def _get_shape_def(self, idx: int) -> z3.ExprRef:
+        """Return the constraint for shape ``idx`` over fresh unknowns sel/A/b/r suffixed by idx."""
+        shape_sel = z3.BoolVector("sel%d" % idx, self.num_shapes)
+        coeff_list = z3.Reals(["A%d__%d_%d" % (idx, i, j) for i in range(self.perc_dim) for j in range(self.state_dim)])
+        coeff_arr = np.array(coeff_list).reshape(self.perc_dim, self.state_dim)
+        intercept_arr = np.array(z3.RealVector("b%d" % idx, self.perc_dim))
+        radius = z3.Real("r%d" % idx)
+        transformed_perc_seq = affine_transform(self._state_vars, coeff_arr, intercept_arr)
+        # At least one selector must hold; each selector bounds its norm by the radius.
+        return z3.And(
+            z3.AtLeast(*shape_sel, 1),
+            z3.Implies(shape_sel[0], l1_norm(*transformed_perc_seq)  <= radius),
+            z3.Implies(shape_sel[1], l2_norm(*transformed_perc_seq)  <= radius),
+            z3.Implies(shape_sel[2], loo_norm(*transformed_perc_seq) <= radius),
+        )
+
+    def add_positive_examples(self, *vals) -> None:
+        """Assert the shape predicate with each ``val`` substituted for the free variables."""
+        assert all(len(val) == self.state_dim+self.perc_dim for val in vals)
+        self._solver.add(*(z3.substitute_vars(self._in_shape_pred, *val) for val in vals))
+
+    def add_negative_examples(self, *vals) -> None:
+        """Assert the negated shape predicate with each ``val`` substituted for the free variables."""
+        assert all(len(val) == self.state_dim+self.perc_dim for val in vals)
+        self._solver.add(*(z3.Not(z3.substitute_vars(self._in_shape_pred, *val)) for val in vals))
+
+    def add_implication_examples(self, *args) -> None:
+        """Forward the arguments unchanged to the base-class implementation."""
+        return super().add_implication_examples(*args)
+
+
+def test_affine_transform():
+    """Smoke test: print a symbolic 2x3 coefficient matrix, then its affine transform."""
+    xs = z3.RealVector("x", 3)
+    zs = z3.RealVector("z", 2)
+    names = ["A__%d_%d" % (i, j) for i in range(len(zs)) for j in range(len(xs))]
+    coeff = np.array(z3.Reals(names)).reshape((len(zs), len(xs)))
+    print(coeff)
+    intercept = np.array([7.0, 8.0])
+    print(affine_transform(xs, coeff, intercept))
+
+
+def test_norms():
+    """Smoke test: print the l1, l2 and loo norm terms built over three z3 Real vars."""
+    state_vars = z3.RealVarVector(3)
+    shifted = [v + float(i) for i, v in enumerate(state_vars)]
+    print(l1_norm(*shifted))
+    print(l2_norm(*shifted))
+    # loo_norm is exercised on the raw variables, without the index shift.
+    print(loo_norm(*state_vars))
+
+
+def test_firstpass_learner():
+    """Smoke test: add one positive example and print what ``learn`` returns."""
+    learner = FirstpassLearner()
+    learner.add_positive_examples([z3.RealVal(0.0)]*learner.state_dim + [z3.RealVal(0.5)]*learner.perc_dim)
+    print(learner.learn())
+
+if __name__ == "__main__":
+    # Run the smoke tests in sequence when this module is executed as a script.
+    for _smoke_test in (test_affine_transform, test_norms, test_firstpass_learner):
+        _smoke_test()
diff --git a/learner_base.py b/learner_base.py
index 4d8d100..2454451 100644
--- a/learner_base.py
+++ b/learner_base.py
@@ -1,4 +1,7 @@
 import abc
+from typing import Optional
+
+import z3
 
 class LearnerBase(abc.ABC):
     def __init__(self) -> None:
@@ -23,3 +26,26 @@ class LearnerBase(abc.ABC):
     @abc.abstractmethod
     def learn(self):
         raise NotImplementedError
+
+
+class Z3LearnerBase(LearnerBase):
+    """Learner base that accumulates constraints in a z3 solver and solves them in ``learn``."""
+    def __init__(self, state_dim, perc_dim) -> None:
+        super().__init__()
+        self._solver = z3.SolverFor('QF_LRA')
+        # One shared index space: state vars come first, percept vars follow, so they never alias.
+        all_vars = z3.RealVarVector(state_dim + perc_dim)
+        self._state_vars = all_vars[:state_dim]
+        self._percept_vars = all_vars[state_dim:]
+
+    @property
+    def state_dim(self) -> int:
+        return len(self._state_vars)
+
+    @property
+    def perc_dim(self) -> int:
+        return len(self._percept_vars)
+
+    def learn(self) -> Optional[z3.ModelRef]:
+        """Return a model of the accumulated constraints, or None when the check is not sat."""
+        return self._solver.model() if self._solver.check() == z3.sat else None
-- 
GitLab