Skip to content
Snippets Groups Projects
Commit b4bcb5e6 authored by chsieh16's avatar chsieh16
Browse files

Add SMT encoding for learner of firstpass example

parent bb081b82
No related branches found
No related tags found
No related merge requests found
from typing import Sequence
import numpy as np
import z3
from learner_base import Z3LearnerBase
def affine_transform(exprs: Sequence[z3.ExprRef],
coeff: np.ndarray, intercept: np.ndarray) -> Sequence[z3.ExprRef]:
assert len(coeff.shape) == 2 # Matrix is a 2D array
assert coeff.shape == (intercept.shape[0], len(exprs))
return [z3.Sum(*(coeff[i][j]*exprs[j] for j in range(coeff.shape[1])), intercept[i])
for i in range(coeff.shape[0])]
def abs_expr(expr: z3.ExprRef) -> z3.ExprRef:
return z3.If(expr >= 0, expr, -expr)
def l1_norm(*exprs) -> z3.ExprRef:
return z3.Sum(*(abs_expr(expr) for expr in exprs))
def l2_norm(*exprs) -> z3.ExprRef:
return z3.Sum(*(expr**2 for expr in exprs))
def max_expr(*exprs) -> z3.ExprRef:
m = exprs[0]
for v in exprs[1:]:
m = z3.If(m >= v, m, v)
return m
def loo_norm(*exprs) -> z3.ExprRef:
return max_expr(*(abs_expr(expr) for expr in exprs))
class FirstpassLearner(Z3LearnerBase):
def __init__(self) -> None:
super().__init__(state_dim=2, perc_dim=2)
self._in_shape_pred: z3.ExprRef = z3.BoolVal(True)
self.set_grammar(None)
@property
def num_shapes(self) -> int:
return len([l1_norm, l2_norm, loo_norm])
def set_grammar(self, grammar) -> None:
# TODO replace hardcoded grammar
self._in_shape_pred = self._in_shape_pred and self._get_shape_def(0)
def _get_shape_def(self, idx: int) -> z3.ExprRef:
shape_sel = z3.BoolVector("sel%d" % idx, self.num_shapes)
coeff_list = z3.Reals(["A%d__%d_%d" % (idx, i, j) for i in range(self.perc_dim) for j in range(self.state_dim)])
coeff_arr = np.array(coeff_list).reshape(self.perc_dim, self.state_dim)
intercept_arr = np.array(z3.RealVector("b%d" % idx, self.perc_dim))
radius = z3.Real("r%d" % idx)
transformed_perc_seq = affine_transform(self._state_vars, coeff_arr, intercept_arr)
return z3.And(
z3.AtLeast(*shape_sel, 1),
z3.Implies(shape_sel[0], l1_norm(*transformed_perc_seq) <= radius),
z3.Implies(shape_sel[1], l2_norm(*transformed_perc_seq) <= radius),
z3.Implies(shape_sel[2], loo_norm(*transformed_perc_seq) <= radius),
)
def add_positive_examples(self, *vals) -> None:
assert all(len(val) == self.state_dim+self.perc_dim for val in vals)
self._solver.add(*(
z3.substitute_vars(self._in_shape_pred, *val) for val in vals
))
def add_negative_examples(self, *vals) -> None:
assert all(len(val) == self.state_dim+self.perc_dim for val in vals)
self._solver.add(*(
z3.Not(z3.substitute_vars(self._in_shape_pred, *val)) for val in vals
))
def add_implication_examples(self, *args) -> None:
return super().add_implication_examples(*args)
def test_affine_transform():
state_vars = z3.RealVector("x", 3)
percept_vars = z3.RealVector("z", 2)
coeff_list = z3.Reals(["A__%d_%d" % (i, j) for i in range(len(percept_vars)) for j in range(len(state_vars))])
coeff = np.array(coeff_list).reshape((len(percept_vars), len(state_vars)))
print(coeff)
intercept = np.array([7.0, 8.0])
ret = affine_transform(state_vars, coeff, intercept)
print(ret)
def test_norms():
state_vars = z3.RealVarVector(3)
ret = l1_norm(*(v + float(i) for i, v in enumerate(state_vars)))
print(ret)
ret = l2_norm(*(v + float(i) for i, v in enumerate(state_vars)))
print(ret)
ret = loo_norm(*(v for i, v in enumerate(state_vars)))
print(ret)
def test_firstpass_learner():
learner = FirstpassLearner()
vals = [z3.RealVal(0.0)]*learner.state_dim + [z3.RealVal(0.5)]*learner.perc_dim
learner.add_positive_examples(vals)
print(learner.learn())
if __name__ == "__main__":
test_affine_transform()
test_norms()
test_firstpass_learner()
import abc
from typing import Optional
import z3
class LearnerBase(abc.ABC):
def __init__(self) -> None:
......@@ -23,3 +26,26 @@ class LearnerBase(abc.ABC):
@abc.abstractmethod
def learn(self):
raise NotImplementedError
class Z3LearnerBase(LearnerBase):
def __init__(self, state_dim, perc_dim) -> None:
super().__init__()
self._solver = z3.SolverFor('QF_LRA')
self._state_vars = z3.RealVarVector(state_dim)
self._percept_vars = z3.RealVarVector(perc_dim)
@property
def state_dim(self) -> int:
return len(self._state_vars)
@property
def perc_dim(self) -> int:
return len(self._percept_vars)
def learn(self) -> Optional[z3.ModelRef]:
res = self._solver.check()
if res == z3.sat:
return self._solver.model()
else:
return None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment