import re
from typing import Iterator

import gurobipy as gp
import z3

from teacher_base import GurobiTeacherBase


class DTreeGurobiTeacherBase(GurobiTeacherBase):
    """Gurobi-backed teacher that checks decision-tree shaped candidates.

    A candidate formula (nested if-then-else over linear comparisons) is split
    into one conjunction per accepting root-to-leaf path.  Each conjunction is
    translated into Gurobi constraints and the model is solved; solutions that
    are not spurious are collected as counterexamples.

    NOTE(review): ``_gp_model``, ``_prev_candidate_constr``, ``_cexs``,
    ``_old_state``, ``_percept``, ``state_dim``, ``perc_dim`` and
    ``is_spurious_example`` are not defined in this file; they are presumably
    provided by ``GurobiTeacherBase`` -- confirm against the base class.
    """

    # Margin used by `_set_candidate` to turn a negated (i.e. strict)
    # inequality into a non-strict one that Gurobi can represent.
    PRECISION = 10**-3
    # Splits a Z3 constant name such as "x_3" into its name and index parts.
    Z3_VAR_RE = re.compile(r"(?P<var>\w+)_(?P<idx>\d+)")

    def _build_affine_expr(self, z3_expr: z3.ExprRef):
        """Translate a Z3 arithmetic term into a Gurobi-compatible expression.

        Rational numerals become ``fractions.Fraction`` values.  A Z3 constant
        whose name matches ``<var>_<idx>`` is looked up in the Gurobi model
        under the name ``<var>[<idx>]``.  Sums and binary products are
        translated recursively.

        Raises:
            AssertionError: the name does not match ``Z3_VAR_RE`` or the Gurobi
                model has no variable with the derived name.
            NotImplementedError: a product has three or more operands.
            RuntimeError: any other kind of expression.
        """
        if z3.is_rational_value(z3_expr):
            return z3_expr.as_fraction()

        if z3.is_var(z3_expr) or z3.is_const(z3_expr):
            result = self.Z3_VAR_RE.search(str(z3_expr))
            assert result is not None, str(z3_expr)
            var_name, idx = result.group("var"), result.group("idx")
            gp_var = self._gp_model.getVarByName(f"{var_name}[{idx}]")
            assert gp_var is not None
            return gp_var

        if z3.is_add(z3_expr):
            return sum(self._build_affine_expr(arg) for arg in z3_expr.children())

        if z3.is_mul(z3_expr):
            if len(z3_expr.children()) > 2:
                raise NotImplementedError("TODO: multiplication of three or more operands")
            lhs = self._build_affine_expr(z3_expr.arg(0))
            rhs = self._build_affine_expr(z3_expr.arg(1))
            return lhs * rhs

        raise RuntimeError(f"Only support affine expressions. {z3_expr.children()}")

    def _set_candidate(self, conjunct: z3.BoolRef) -> None:
        """Replace the previous candidate's constraints with those of `conjunct`.

        The conjunct is simplified first (flattened, arithmetic moved to the
        left-hand side) and must then be ``True``, a single atom, or a
        conjunction of atoms.  An atom is ``==``, ``<=``, ``>=``, or the
        negation of ``<=`` / ``>=``.  Negated inequalities are strict, so they
        are tightened by ``PRECISION`` to obtain a non-strict constraint.

        Raises:
            RuntimeError: the simplified formula is not a conjunction of
                supported atoms; this includes a negated equality, which
                cannot be expressed as a single linear constraint.
            AssertionError: an atom is not ``==``, ``<=`` or ``>=``.
        """
        # Variable aliases
        m = self._gp_model

        # Remove constraints from the previous candidate first
        m.remove(self._prev_candidate_constr)
        self._prev_candidate_constr.clear()
        m.update()

        conjunct = z3.simplify(conjunct, flat=True, arith_lhs=True)
        if z3.is_true(conjunct):
            # Nothing to constrain.
            return
        elif z3.is_and(conjunct):
            pred_list = list(conjunct.children())
        elif z3.is_eq(conjunct) or z3.is_le(conjunct) \
                or z3.is_ge(conjunct) or z3.is_not(conjunct):
            pred_list = [conjunct]
        else:
            raise RuntimeError(f"{conjunct} should be a conjunction.")

        for orig_pred in pred_list:
            negated = z3.is_not(orig_pred)
            pred = orig_pred.arg(0) if negated else orig_pred
            assert z3.is_eq(pred) or z3.is_le(pred) or z3.is_ge(pred), str(pred)

            if negated and z3.is_eq(pred):
                # Previously the negation was silently dropped and the atom was
                # added as an equality, constraining the wrong region.
                raise RuntimeError(f"Unsupported atomic predicate expression {orig_pred}")

            lhs = self._build_affine_expr(pred.arg(0))
            rhs = self._build_affine_expr(pred.arg(1))
            if z3.is_eq(pred):
                cons = (lhs == rhs)
            elif z3.is_ge(pred):
                if not negated:
                    cons = (lhs >= rhs)
                else:
                    # !(lhs >= rhs) <=> (lhs < rhs) => lhs <= rhs - PRECISION
                    cons = (lhs <= rhs - self.PRECISION)
            elif z3.is_le(pred):
                if not negated:
                    cons = (lhs <= rhs)
                else:
                    # !(lhs <= rhs) <=> (lhs > rhs) => lhs >= rhs + PRECISION
                    cons = (lhs >= rhs + self.PRECISION)
            else:
                # Only reachable when assertions are disabled (python -O).
                raise RuntimeError(f"Unsupported atomic predicate expression {pred}")

            gp_cons = m.addConstr(cons)
            self._prev_candidate_constr.append(gp_cons)

    def _candidate_to_conjuncts(self, candidate: z3.BoolRef) -> Iterator[z3.BoolRef]:
        """Yield one conjunction per accepting path of an if-then-else tree.

        The tree is walked depth-first with an explicit stack.  Every stack
        entry is a path: the conditions collected so far followed by the node
        still to be expanded.  A ``False`` leaf yields nothing, a ``True`` leaf
        yields the conjunction of the path conditions, and a comparison leaf
        yields the path conditions together with that comparison.  The "then"
        branch of an ITE is explored before its "else" branch.

        Raises:
            RuntimeError: a node is neither a Boolean constant, a comparison
                nor an ITE, or an ITE condition is not one of
                ``<``, ``<=``, ``>``, ``>=``.
        """
        stack = [[candidate]]
        while stack:
            curr_path = stack.pop()  # remove this path
            curr_node = curr_path.pop()  # remove last node in this path
            if z3.is_false(curr_node):
                # The leaf node in this path is false. Skip
                continue
            elif z3.is_true(curr_node):
                if not curr_path:
                    yield z3.BoolVal(True)
                else:
                    yield z3.And(*curr_path)
            elif z3.is_gt(curr_node) or z3.is_ge(curr_node) \
                    or z3.is_lt(curr_node) or z3.is_le(curr_node):
                yield z3.And(*curr_path, curr_node)
            elif z3.is_app_of(curr_node, z3.Z3_OP_ITE):
                cond, left, right = curr_node.children()
                assert len(cond.children()) == 2
                lhs, rhs = cond.children()
                # Negate the condition for the "else" branch.  All four
                # comparisons accepted as leaves are accepted here as well.
                if z3.is_le(cond):
                    not_cond = lhs > rhs
                elif z3.is_ge(cond):
                    not_cond = lhs < rhs
                elif z3.is_lt(cond):
                    not_cond = lhs >= rhs
                elif z3.is_gt(cond):
                    not_cond = lhs <= rhs
                else:
                    raise RuntimeError(f"Unexpected condition {cond} for ITE")
                # Push the "else" path first so the "then" path is popped next.
                stack.append(curr_path + [not_cond, right])
                stack.append(curr_path + [cond, left])
            else:
                raise RuntimeError(f"Candidate formula {curr_node} should have been converted to DNF.")

    def check(self, candidate: z3.BoolRef) -> z3.CheckSatResult:
        """Search for counterexamples to `candidate`, one conjunct at a time.

        Returns:
            ``z3.sat`` if at least one non-spurious counterexample was stored
            in ``self._cexs``, ``z3.unsat`` if none was found, and
            ``z3.unknown`` as soon as a solve ends with a status other than
            optimal, suboptimal, infeasible or interrupted.

        Raises:
            KeyboardInterrupt: the Gurobi solve was interrupted.
        """
        self._cexs.clear()
        # Materialize once: the count is printed and the same conjuncts are
        # then solved, so the tree does not need to be traversed twice.
        conjuncts = list(self._candidate_to_conjuncts(candidate))
        # TODO(ANGELLO): add check if too many conjunctions
        print(f"number of conjunct: {len(conjuncts)}")
        print("Checking candidate", flush=True)
        for conjunct in conjuncts:
            self._set_candidate(conjunct)
            self._gp_model.optimize()
            if self._gp_model.status == gp.GRB.INF_OR_UNBD:
                # Re-solve with dual reductions disabled to get a definite
                # infeasible-or-unbounded answer.
                self._gp_model.setParam("DualReductions", 0)
                self._gp_model.optimize()

            if self._gp_model.status in [gp.GRB.OPTIMAL, gp.GRB.SUBOPTIMAL]:
                cex_list = []
                for i in range(self._gp_model.SolCount):
                    # Select the i-th stored solution before reading `Xn`.
                    self._gp_model.Params.SolutionNumber = i
                    cex = tuple(self._old_state.Xn) + tuple(self._percept.Xn)
                    cex_list.append(cex)
                filtered_cex_list = [
                    cex for cex in cex_list
                    if not self.is_spurious_example(
                        self.state_dim, self.perc_dim, conjunct, cex)
                ]
                # A conjunct whose solutions are all spurious is deliberately
                # ignored rather than treated as an error.
                if filtered_cex_list:
                    self._cexs.extend(filtered_cex_list)
            elif self._gp_model.status == gp.GRB.INFEASIBLE:
                continue
            elif self._gp_model.status == gp.GRB.INTERRUPTED:
                raise KeyboardInterrupt
            else:
                return z3.unknown
        print("Done")
        if len(self._cexs) > 0:
            return z3.sat
        else:
            return z3.unsat