Skip to content
Snippets Groups Projects
Commit b8382e5c authored by chsieh16's avatar chsieh16
Browse files

Use only circle for shape

parent b4bcb5e6
No related branches found
No related tags found
No related merge requests found
......@@ -6,12 +6,13 @@ import z3
from learner_base import Z3LearnerBase
def affine_transform(exprs: Sequence[z3.ExprRef],
def affine_transform(state_vars: Sequence[z3.ExprRef], perc_vars: Sequence[z3.ExprRef],
coeff: np.ndarray, intercept: np.ndarray) -> Sequence[z3.ExprRef]:
assert len(coeff.shape) == 2 # Matrix is a 2D array
assert coeff.shape == (intercept.shape[0], len(exprs))
assert coeff.shape == (intercept.shape[0], len(state_vars))
assert intercept.shape[0] == len(perc_vars)
return [z3.Sum(*(coeff[i][j]*exprs[j] for j in range(coeff.shape[1])), intercept[i])
return [perc_vars[i] - z3.Sum(*(coeff[i][j]*state_vars[j] for j in range(coeff.shape[1])), intercept[i])
for i in range(coeff.shape[0])]
def abs_expr(expr: z3.ExprRef) -> z3.ExprRef:
......@@ -56,25 +57,27 @@ class FirstpassLearner(Z3LearnerBase):
intercept_arr = np.array(z3.RealVector("b%d" % idx, self.perc_dim))
radius = z3.Real("r%d" % idx)
transformed_perc_seq = affine_transform(self._state_vars, coeff_arr, intercept_arr)
transformed_perc_seq = affine_transform(self._state_vars, self._percept_vars, coeff_arr, intercept_arr)
return z3.And(
z3.AtLeast(*shape_sel, 1),
z3.Implies(shape_sel[0], l1_norm(*transformed_perc_seq) <= radius),
z3.Implies(shape_sel[1], l2_norm(*transformed_perc_seq) <= radius),
z3.Implies(shape_sel[2], loo_norm(*transformed_perc_seq) <= radius),
l2_norm(*transformed_perc_seq) <= radius,
# z3.AtLeast(*shape_sel, 1),
# z3.Implies(shape_sel[0], l1_norm(*transformed_perc_seq) <= radius),
# z3.Implies(shape_sel[1], l2_norm(*transformed_perc_seq) <= radius),
# z3.Implies(shape_sel[2], loo_norm(*transformed_perc_seq) <= radius),
)
def add_positive_examples(self, *vals) -> None:
assert all(len(val) == self.state_dim+self.perc_dim for val in vals)
self._solver.add(*(
z3.substitute_vars(self._in_shape_pred, *val) for val in vals
z3.substitute_vars(self._in_shape_pred,
*(z3.RealVal(v) for v in val)) for val in vals
))
def add_negative_examples(self, *vals) -> None:
assert all(len(val) == self.state_dim+self.perc_dim for val in vals)
self._solver.add(*(
z3.Not(z3.substitute_vars(self._in_shape_pred, *val)) for val in vals
z3.Not(z3.substitute_vars(self._in_shape_pred,
*(z3.RealVal(v) for v in val))) for val in vals
))
def add_implication_examples(self, *args) -> None:
......@@ -104,11 +107,22 @@ def test_norms():
def test_firstpass_learner():
learner = FirstpassLearner()
vals = [z3.RealVal(0.0)]*learner.state_dim + [z3.RealVal(0.5)]*learner.perc_dim
learner.add_positive_examples(vals)
pos_examples = [
(10.0, -10.0, 11.0, -11.0),
(10.0, -10.0, 9.0, -9.0),
(1.0, 1.0, 1.5, 2.3),
]
learner.add_positive_examples(*pos_examples)
neg_examples = [
(1.0, 1.0, 0.5, 0.5),
(1.0, 1.0, 1.5, 1.5),
(9.0, 9.0, 5.0, 5.0),
]
learner.add_negative_examples(*neg_examples)
print(learner._solver.assertions())
print(learner.learn())
if __name__ == "__main__":
test_affine_transform()
test_norms()
# test_affine_transform()
# test_norms()
test_firstpass_learner()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment