From b4bcb5e62a4108c92355f1261b4f4387321a7c25 Mon Sep 17 00:00:00 2001 From: "Hsieh, Chiao" <chsieh16@illinois.edu> Date: Mon, 29 Nov 2021 15:38:53 -0600 Subject: [PATCH] Add SMT encoding for learner of firstpass example --- firstpass_learner.py | 114 +++++++++++++++++++++++++++++++++++++++++++ learner_base.py | 26 ++++++++++ 2 files changed, 140 insertions(+) create mode 100644 firstpass_learner.py diff --git a/firstpass_learner.py b/firstpass_learner.py new file mode 100644 index 0000000..b0c5679 --- /dev/null +++ b/firstpass_learner.py @@ -0,0 +1,114 @@ +from typing import Sequence + +import numpy as np +import z3 + +from learner_base import Z3LearnerBase + + +def affine_transform(exprs: Sequence[z3.ExprRef], + coeff: np.ndarray, intercept: np.ndarray) -> Sequence[z3.ExprRef]: + assert len(coeff.shape) == 2 # Matrix is a 2D array + assert coeff.shape == (intercept.shape[0], len(exprs)) + + return [z3.Sum(*(coeff[i][j]*exprs[j] for j in range(coeff.shape[1])), intercept[i]) + for i in range(coeff.shape[0])] + +def abs_expr(expr: z3.ExprRef) -> z3.ExprRef: + return z3.If(expr >= 0, expr, -expr) + +def l1_norm(*exprs) -> z3.ExprRef: + return z3.Sum(*(abs_expr(expr) for expr in exprs)) + +def l2_norm(*exprs) -> z3.ExprRef: + return z3.Sum(*(expr**2 for expr in exprs)) + +def max_expr(*exprs) -> z3.ExprRef: + m = exprs[0] + for v in exprs[1:]: + m = z3.If(m >= v, m, v) + return m + +def loo_norm(*exprs) -> z3.ExprRef: + return max_expr(*(abs_expr(expr) for expr in exprs)) + + +class FirstpassLearner(Z3LearnerBase): + def __init__(self) -> None: + super().__init__(state_dim=2, perc_dim=2) + + self._in_shape_pred: z3.ExprRef = z3.BoolVal(True) + self.set_grammar(None) + + @property + def num_shapes(self) -> int: + return len([l1_norm, l2_norm, loo_norm]) + + def set_grammar(self, grammar) -> None: + # TODO replace hardcoded grammar + self._in_shape_pred = self._in_shape_pred and self._get_shape_def(0) + + def _get_shape_def(self, idx: int) -> z3.ExprRef: + shape_sel = z3.BoolVector("sel%d" % idx, self.num_shapes) + + coeff_list = z3.Reals(["A%d__%d_%d" % (idx, i, j) for i in range(self.perc_dim) for j in range(self.state_dim)]) + coeff_arr = np.array(coeff_list).reshape(self.perc_dim, self.state_dim) + intercept_arr = np.array(z3.RealVector("b%d" % idx, self.perc_dim)) + radius = z3.Real("r%d" % idx) + + transformed_perc_seq = affine_transform(self._state_vars, coeff_arr, intercept_arr) + return z3.And( + z3.AtLeast(*shape_sel, 1), + z3.Implies(shape_sel[0], l1_norm(*transformed_perc_seq) <= radius), + z3.Implies(shape_sel[1], l2_norm(*transformed_perc_seq) <= radius), + z3.Implies(shape_sel[2], loo_norm(*transformed_perc_seq) <= radius), + ) + + def add_positive_examples(self, *vals) -> None: + assert all(len(val) == self.state_dim+self.perc_dim for val in vals) + + self._solver.add(*( + z3.substitute_vars(self._in_shape_pred, *val) for val in vals + )) + + def add_negative_examples(self, *vals) -> None: + assert all(len(val) == self.state_dim+self.perc_dim for val in vals) + self._solver.add(*( + z3.Not(z3.substitute_vars(self._in_shape_pred, *val)) for val in vals + )) + + def add_implication_examples(self, *args) -> None: + return super().add_implication_examples(*args) + + +def test_affine_transform(): + state_vars = z3.RealVector("x", 3) + percept_vars = z3.RealVector("z", 2) + coeff_list = z3.Reals(["A__%d_%d" % (i, j) for i in range(len(percept_vars)) for j in range(len(state_vars))]) + coeff = np.array(coeff_list).reshape((len(percept_vars), len(state_vars))) + print(coeff) + intercept = np.array([7.0, 8.0]) + ret = affine_transform(state_vars, coeff, intercept) + print(ret) + + +def test_norms(): + state_vars = z3.RealVarVector(3) + ret = l1_norm(*(v + float(i) for i, v in enumerate(state_vars))) + print(ret) + ret = l2_norm(*(v + float(i) for i, v in enumerate(state_vars))) + print(ret) + ret = loo_norm(*(v for i, v in enumerate(state_vars))) + print(ret) + + +def test_firstpass_learner(): + learner = FirstpassLearner() + vals = [z3.RealVal(0.0)]*learner.state_dim + [z3.RealVal(0.5)]*learner.perc_dim + learner.add_positive_examples(vals) + print(learner.learn()) + +if __name__ == "__main__": + test_affine_transform() + test_norms() + test_firstpass_learner() diff --git a/learner_base.py b/learner_base.py index 4d8d100..2454451 100644 --- a/learner_base.py +++ b/learner_base.py @@ -1,4 +1,7 @@ import abc +from typing import Optional + +import z3 class LearnerBase(abc.ABC): def __init__(self) -> None: @@ -23,3 +26,26 @@ class LearnerBase(abc.ABC): @abc.abstractmethod def learn(self): raise NotImplementedError + + +class Z3LearnerBase(LearnerBase): + def __init__(self, state_dim, perc_dim) -> None: + super().__init__() + self._solver = z3.SolverFor('QF_LRA') + self._state_vars = z3.RealVarVector(state_dim) + self._percept_vars = z3.RealVarVector(perc_dim) + + @property + def state_dim(self) -> int: + return len(self._state_vars) + + @property + def perc_dim(self) -> int: + return len(self._percept_vars) + + def learn(self) -> Optional[z3.ModelRef]: + res = self._solver.check() + if res == z3.sat: + return self._solver.model() + else: + return None -- GitLab