Skip to content
Snippets Groups Projects
Commit 8cc8f2aa authored by chsieh16's avatar chsieh16
Browse files

Support selecting different shapes in both teacher and learner

parent 9cf78046
No related branches found
No related tags found
No related merge requests found
...@@ -18,13 +18,13 @@ def affine_transform(state_vars: Sequence[z3.ExprRef], perc_vars: Sequence[z3.Ex ...@@ -18,13 +18,13 @@ def affine_transform(state_vars: Sequence[z3.ExprRef], perc_vars: Sequence[z3.Ex
def abs_expr(expr: z3.ExprRef) -> z3.ExprRef: def abs_expr(expr: z3.ExprRef) -> z3.ExprRef:
return z3.If(expr >= 0, expr, -expr) return z3.If(expr >= 0, expr, -expr)
def l1_norm(*exprs) -> z3.ExprRef: def l1_norm(*exprs: z3.ExprRef) -> z3.ExprRef:
return z3.Sum(*(abs_expr(expr) for expr in exprs)) return z3.Sum(*(abs_expr(expr) for expr in exprs))
def l2_norm(*exprs) -> z3.ExprRef: def l2_norm(*exprs: z3.ExprRef) -> z3.ExprRef:
return z3.Sum(*(expr*expr for expr in exprs)) return z3.Sum(*(expr*expr for expr in exprs))
def max_expr(*exprs) -> z3.ExprRef: def max_expr(*exprs: z3.ExprRef) -> z3.ExprRef:
m = exprs[0] m = exprs[0]
for v in exprs[1:]: for v in exprs[1:]:
m = z3.If(m >= v, m, v) m = z3.If(m >= v, m, v)
...@@ -58,15 +58,14 @@ class FirstpassLearner(Z3LearnerBase): ...@@ -58,15 +58,14 @@ class FirstpassLearner(Z3LearnerBase):
intercept_arr = np.array(z3.RealVector("b%d" % idx, self.perc_dim)) intercept_arr = np.array(z3.RealVector("b%d" % idx, self.perc_dim))
radius = z3.Real("r%d" % idx) radius = z3.Real("r%d" % idx)
self._part_affine_map[idx] = (coeff_arr, intercept_arr, radius) self._part_affine_map[idx] = (shape_sel, coeff_arr, intercept_arr, radius)
transformed_perc_seq = affine_transform(self._state_vars, self._percept_vars, coeff_arr, intercept_arr) transformed_perc_seq = affine_transform(self._state_vars, self._percept_vars, coeff_arr, intercept_arr)
return z3.And( return z3.And(
l2_norm(*transformed_perc_seq) <= radius, z3.AtLeast(*shape_sel, 1),
# z3.AtLeast(*shape_sel, 1), z3.Implies(shape_sel[0], l1_norm(*transformed_perc_seq) <= radius),
# z3.Implies(shape_sel[0], l1_norm(*transformed_perc_seq) <= radius), z3.Implies(shape_sel[1], l2_norm(*transformed_perc_seq) <= radius),
# z3.Implies(shape_sel[1], l2_norm(*transformed_perc_seq) <= radius), z3.Implies(shape_sel[2], loo_norm(*transformed_perc_seq) <= radius),
# z3.Implies(shape_sel[2], loo_norm(*transformed_perc_seq) <= radius),
) )
def add_positive_examples(self, *vals) -> None: def add_positive_examples(self, *vals) -> None:
...@@ -89,12 +88,13 @@ class FirstpassLearner(Z3LearnerBase): ...@@ -89,12 +88,13 @@ class FirstpassLearner(Z3LearnerBase):
def learn(self) -> Optional[Tuple]: def learn(self) -> Optional[Tuple]:
res = self._solver.check() res = self._solver.check()
if res == z3.sat: if res == z3.sat:
z3_coeff_arr, z3_intercept_arr, z3_radius = self._part_affine_map[0] # FIXME z3_shape_sel, z3_coeff_arr, z3_intercept_arr, z3_radius = self._part_affine_map[0] # FIXME
m = self._solver.model() m = self._solver.model()
coeff_arr = np.vectorize(extract_from_z3_model(m))(z3_coeff_arr) shape_sel = extract_shape_from_z3_model(m, z3_shape_sel) # TODO: return integer is not that good
intercept_arr = np.vectorize(extract_from_z3_model(m))(z3_intercept_arr) coeff_arr = np.vectorize(extract_real_from_z3_model(m))(z3_coeff_arr)
radius = extract_from_z3_model(m)(z3_radius) intercept_arr = np.vectorize(extract_real_from_z3_model(m))(z3_intercept_arr)
return coeff_arr, intercept_arr, radius radius = extract_real_from_z3_model(m)(z3_radius)
return shape_sel, coeff_arr, intercept_arr, radius
elif res == z3.unsat: elif res == z3.unsat:
print("(unsat)") print("(unsat)")
return None return None
...@@ -103,8 +103,14 @@ class FirstpassLearner(Z3LearnerBase): ...@@ -103,8 +103,14 @@ class FirstpassLearner(Z3LearnerBase):
return None return None
def extract_from_z3_model(m: z3.ModelRef): def extract_shape_from_z3_model(m: z3.ModelRef, vars: Sequence[z3.ExprRef]) -> int:
def to_float(var): for i, var in enumerate(vars):
if m[var]:
return i
def extract_real_from_z3_model(m: z3.ModelRef):
def to_float(var: z3.ExprRef) -> float:
x2 = m[var].as_fraction() x2 = m[var].as_fraction()
return float(x2.numerator) / float(x2.denominator) return float(x2.numerator) / float(x2.denominator)
return to_float return to_float
...@@ -145,7 +151,7 @@ def test_firstpass_learner(): ...@@ -145,7 +151,7 @@ def test_firstpass_learner():
(9.0, 9.0, 5.0, 5.0), (9.0, 9.0, 5.0, 5.0),
] ]
learner.add_negative_examples(*neg_examples) learner.add_negative_examples(*neg_examples)
print(learner._solver.assertions()) # print(learner._solver.assertions())
print(learner.learn()) print(learner.learn())
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -25,7 +25,7 @@ class FirstpassTeacher(GurobiTeacherBase): ...@@ -25,7 +25,7 @@ class FirstpassTeacher(GurobiTeacherBase):
assert len(old_state) == len(new_state) assert len(old_state) == len(new_state)
m = self._gp_model m = self._gp_model
m.update() m.update()
old_V, new_V = 0, 0 old_V, new_V = 0.0, 0.0
for old_xi, new_xi in zip(old_state, new_state): for old_xi, new_xi in zip(old_state, new_state):
old_var_name = "|%s|" % old_xi.VarName old_var_name = "|%s|" % old_xi.VarName
old_abs_xi = m.addVar(name=old_var_name,vtype=GRB.CONTINUOUS, lb=0, ub=np.inf) old_abs_xi = m.addVar(name=old_var_name,vtype=GRB.CONTINUOUS, lb=0, ub=np.inf)
...@@ -40,7 +40,7 @@ class FirstpassTeacher(GurobiTeacherBase): ...@@ -40,7 +40,7 @@ class FirstpassTeacher(GurobiTeacherBase):
def _add_candidate(self, candidate) -> None: def _add_candidate(self, candidate) -> None:
# TODO parse candidate # TODO parse candidate
coeff, intercept, radius = candidate shape_sel, coeff, intercept, radius = candidate
assert coeff.shape == (self._percept.shape[0], self._old_state.shape[0]) assert coeff.shape == (self._percept.shape[0], self._old_state.shape[0])
assert intercept.shape == self._percept.shape assert intercept.shape == self._percept.shape
...@@ -48,32 +48,56 @@ class FirstpassTeacher(GurobiTeacherBase): ...@@ -48,32 +48,56 @@ class FirstpassTeacher(GurobiTeacherBase):
x = self._old_state x = self._old_state
z = self._percept z = self._percept
# Constraints on values of xP and yP # Constraints on values of xP and yP
z_diff = m.addMVar(name="||z-(Ax+b)||", shape=len(z.tolist()), vtype=GRB.CONTINUOUS, lb=-np.inf, ub=np.inf) z_diff = m.addMVar(name="z-(Ax+b)", shape=len(z.tolist()), vtype=GRB.CONTINUOUS, lb=-np.inf, ub=np.inf)
m.addConstr(z_diff == z - (coeff @ x + intercept)) m.addConstr(z_diff == z - (coeff @ x + intercept))
m.addConstr(z_diff @ z_diff <= radius*radius, "z in M(x)") m.update()
# L2-norm objective if shape_sel == 0:
m.setObjective(z_diff @ z_diff, GRB.MINIMIZE) z_dist = 0.0
for zi in z_diff.tolist():
abs_zi_name = "|%s|" % zi.VarName
abs_zi = m.addVar(name=abs_zi_name, vtype=GRB.CONTINUOUS, lb=0, ub=np.inf)
m.addConstr(abs_zi == gp.abs_(zi))
z_dist += abs_zi
m.addConstr(z_dist <= radius, "||z-(Ax+b)||_1 <= r")
# L1-norm objective
m.setObjective(z_dist, GRB.MINIMIZE)
elif shape_sel == 1:
m.addConstr(z_diff @ z_diff <= radius*radius, "||z-(Ax+b)||^2 <= r^2")
# L2-norm objective
m.setObjective(z_diff @ z_diff, GRB.MINIMIZE)
elif shape_sel == 2:
abs_zi_list = []
for zi in z_diff.tolist():
abs_zi_name = "|%s|" % zi.VarName
abs_zi = m.addVar(name=abs_zi_name, vtype=GRB.CONTINUOUS, lb=0, ub=np.inf)
m.addConstr(abs_zi == gp.abs_(zi))
abs_zi_list.append(abs_zi)
z_dist = m.addVar(name="||z-(Ax+b)||_oo", vtype=GRB.CONTINUOUS, lb=0, ub=np.inf)
m.addConstr(z_dist == gp.max_(abs_zi_list))
m.addConstr(z_dist <= radius, "||z-(Ax+b)||_oo <= r")
m.setObjective(z_dist, GRB.MINIMIZE)
CAND0_A = np.array([[0.0, -1.0], CAND0_A = np.array([[0.0, -1.0],
[0.0, 1.0]]) [0.0, 1.0]])
CAND0_b = np.array([0.0, 0.0]) CAND0_b = np.array([0.0, 0.0])
CAND0_r = 2.0 CAND0_r = 2.0
CAND0 = (CAND0_A, CAND0_b, CAND0_r) CAND0 = (2, CAND0_A, CAND0_b, CAND0_r)
CAND1_A = np.array([[1.0, 0.0], CAND1_A = np.array([[1.0, 0.0],
[0.0, 1.0]]) [0.0, 1.0]])
CAND1_b = np.array([0.0, 0.0]) CAND1_b = np.array([0.0, 0.0])
CAND1_r = 2.0 CAND1_r = 2.0
CAND1 = (CAND1_A, CAND1_b, CAND1_r) CAND1 = (1, CAND1_A, CAND1_b, CAND1_r)
def test_firstpass_teacher(): def test_firstpass_teacher():
teacher = FirstpassTeacher() teacher = FirstpassTeacher()
teacher.set_old_state_bound(lb=[6.0, 5.0], ub=[11.0, 10.0]) teacher.set_old_state_bound(lb=[6.0, 5.0], ub=[11.0, 10.0])
teacher.dump_model()
result = teacher.check(CAND0) result = teacher.check(CAND0)
teacher.dump_model()
print(result) print(result)
if result == z3.sat: if result == z3.sat:
print(teacher.model()) print(teacher.model())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment