Skip to content
Snippets Groups Projects
Commit 8cc8f2aa authored by chsieh16's avatar chsieh16
Browse files

Support selecting different shapes in both teacher and learner

parent 9cf78046
No related branches found
No related tags found
No related merge requests found
......@@ -18,13 +18,13 @@ def affine_transform(state_vars: Sequence[z3.ExprRef], perc_vars: Sequence[z3.Ex
def abs_expr(expr: z3.ExprRef) -> z3.ExprRef:
return z3.If(expr >= 0, expr, -expr)
def l1_norm(*exprs) -> z3.ExprRef:
def l1_norm(*exprs: z3.ExprRef) -> z3.ExprRef:
return z3.Sum(*(abs_expr(expr) for expr in exprs))
def l2_norm(*exprs) -> z3.ExprRef:
def l2_norm(*exprs: z3.ExprRef) -> z3.ExprRef:
return z3.Sum(*(expr*expr for expr in exprs))
def max_expr(*exprs) -> z3.ExprRef:
def max_expr(*exprs: z3.ExprRef) -> z3.ExprRef:
m = exprs[0]
for v in exprs[1:]:
m = z3.If(m >= v, m, v)
......@@ -58,15 +58,14 @@ class FirstpassLearner(Z3LearnerBase):
intercept_arr = np.array(z3.RealVector("b%d" % idx, self.perc_dim))
radius = z3.Real("r%d" % idx)
self._part_affine_map[idx] = (coeff_arr, intercept_arr, radius)
self._part_affine_map[idx] = (shape_sel, coeff_arr, intercept_arr, radius)
transformed_perc_seq = affine_transform(self._state_vars, self._percept_vars, coeff_arr, intercept_arr)
return z3.And(
l2_norm(*transformed_perc_seq) <= radius,
# z3.AtLeast(*shape_sel, 1),
# z3.Implies(shape_sel[0], l1_norm(*transformed_perc_seq) <= radius),
# z3.Implies(shape_sel[1], l2_norm(*transformed_perc_seq) <= radius),
# z3.Implies(shape_sel[2], loo_norm(*transformed_perc_seq) <= radius),
z3.AtLeast(*shape_sel, 1),
z3.Implies(shape_sel[0], l1_norm(*transformed_perc_seq) <= radius),
z3.Implies(shape_sel[1], l2_norm(*transformed_perc_seq) <= radius),
z3.Implies(shape_sel[2], loo_norm(*transformed_perc_seq) <= radius),
)
def add_positive_examples(self, *vals) -> None:
......@@ -89,12 +88,13 @@ class FirstpassLearner(Z3LearnerBase):
def learn(self) -> Optional[Tuple]:
res = self._solver.check()
if res == z3.sat:
z3_coeff_arr, z3_intercept_arr, z3_radius = self._part_affine_map[0] # FIXME
z3_shape_sel, z3_coeff_arr, z3_intercept_arr, z3_radius = self._part_affine_map[0] # FIXME
m = self._solver.model()
coeff_arr = np.vectorize(extract_from_z3_model(m))(z3_coeff_arr)
intercept_arr = np.vectorize(extract_from_z3_model(m))(z3_intercept_arr)
radius = extract_from_z3_model(m)(z3_radius)
return coeff_arr, intercept_arr, radius
shape_sel = extract_shape_from_z3_model(m, z3_shape_sel) # TODO: return integer is not that good
coeff_arr = np.vectorize(extract_real_from_z3_model(m))(z3_coeff_arr)
intercept_arr = np.vectorize(extract_real_from_z3_model(m))(z3_intercept_arr)
radius = extract_real_from_z3_model(m)(z3_radius)
return shape_sel, coeff_arr, intercept_arr, radius
elif res == z3.unsat:
print("(unsat)")
return None
......@@ -103,8 +103,14 @@ class FirstpassLearner(Z3LearnerBase):
return None
def extract_from_z3_model(m: z3.ModelRef):
def to_float(var):
def extract_shape_from_z3_model(m: z3.ModelRef, vars: Sequence[z3.ExprRef]) -> int:
for i, var in enumerate(vars):
if m[var]:
return i
def extract_real_from_z3_model(m: z3.ModelRef):
def to_float(var: z3.ExprRef) -> float:
x2 = m[var].as_fraction()
return float(x2.numerator) / float(x2.denominator)
return to_float
......@@ -145,7 +151,7 @@ def test_firstpass_learner():
(9.0, 9.0, 5.0, 5.0),
]
learner.add_negative_examples(*neg_examples)
print(learner._solver.assertions())
# print(learner._solver.assertions())
print(learner.learn())
if __name__ == "__main__":
......
......@@ -25,7 +25,7 @@ class FirstpassTeacher(GurobiTeacherBase):
assert len(old_state) == len(new_state)
m = self._gp_model
m.update()
old_V, new_V = 0, 0
old_V, new_V = 0.0, 0.0
for old_xi, new_xi in zip(old_state, new_state):
old_var_name = "|%s|" % old_xi.VarName
old_abs_xi = m.addVar(name=old_var_name,vtype=GRB.CONTINUOUS, lb=0, ub=np.inf)
......@@ -40,7 +40,7 @@ class FirstpassTeacher(GurobiTeacherBase):
def _add_candidate(self, candidate) -> None:
# TODO parse candidate
coeff, intercept, radius = candidate
shape_sel, coeff, intercept, radius = candidate
assert coeff.shape == (self._percept.shape[0], self._old_state.shape[0])
assert intercept.shape == self._percept.shape
......@@ -48,32 +48,56 @@ class FirstpassTeacher(GurobiTeacherBase):
x = self._old_state
z = self._percept
# Constraints on values of xP and yP
z_diff = m.addMVar(name="||z-(Ax+b)||", shape=len(z.tolist()), vtype=GRB.CONTINUOUS, lb=-np.inf, ub=np.inf)
z_diff = m.addMVar(name="z-(Ax+b)", shape=len(z.tolist()), vtype=GRB.CONTINUOUS, lb=-np.inf, ub=np.inf)
m.addConstr(z_diff == z - (coeff @ x + intercept))
m.addConstr(z_diff @ z_diff <= radius*radius, "z in M(x)")
m.update()
# L2-norm objective
m.setObjective(z_diff @ z_diff, GRB.MINIMIZE)
if shape_sel == 0:
z_dist = 0.0
for zi in z_diff.tolist():
abs_zi_name = "|%s|" % zi.VarName
abs_zi = m.addVar(name=abs_zi_name, vtype=GRB.CONTINUOUS, lb=0, ub=np.inf)
m.addConstr(abs_zi == gp.abs_(zi))
z_dist += abs_zi
m.addConstr(z_dist <= radius, "||z-(Ax+b)||_1 <= r")
# L1-norm objective
m.setObjective(z_dist, GRB.MINIMIZE)
elif shape_sel == 1:
m.addConstr(z_diff @ z_diff <= radius*radius, "||z-(Ax+b)||^2 <= r^2")
# L2-norm objective
m.setObjective(z_diff @ z_diff, GRB.MINIMIZE)
elif shape_sel == 2:
abs_zi_list = []
for zi in z_diff.tolist():
abs_zi_name = "|%s|" % zi.VarName
abs_zi = m.addVar(name=abs_zi_name, vtype=GRB.CONTINUOUS, lb=0, ub=np.inf)
m.addConstr(abs_zi == gp.abs_(zi))
abs_zi_list.append(abs_zi)
z_dist = m.addVar(name="||z-(Ax+b)||_oo", vtype=GRB.CONTINUOUS, lb=0, ub=np.inf)
m.addConstr(z_dist == gp.max_(abs_zi_list))
m.addConstr(z_dist <= radius, "||z-(Ax+b)||_oo <= r")
m.setObjective(z_dist, GRB.MINIMIZE)
CAND0_A = np.array([[0.0, -1.0],
[0.0, 1.0]])
CAND0_b = np.array([0.0, 0.0])
CAND0_r = 2.0
CAND0 = (CAND0_A, CAND0_b, CAND0_r)
CAND0 = (2, CAND0_A, CAND0_b, CAND0_r)
CAND1_A = np.array([[1.0, 0.0],
[0.0, 1.0]])
CAND1_b = np.array([0.0, 0.0])
CAND1_r = 2.0
CAND1 = (CAND1_A, CAND1_b, CAND1_r)
CAND1 = (1, CAND1_A, CAND1_b, CAND1_r)
def test_firstpass_teacher():
teacher = FirstpassTeacher()
teacher.set_old_state_bound(lb=[6.0, 5.0], ub=[11.0, 10.0])
teacher.dump_model()
result = teacher.check(CAND0)
teacher.dump_model()
print(result)
if result == z3.sat:
print(teacher.model())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment