Newer
Older
# Matches Z3 constant names of the form "<var>_<idx>" (e.g. "x_0"); the
# captured groups are used to look up the Gurobi variable named "<var>[<idx>]".
Z3_VAR_RE = re.compile(r"(?P<var>\w+)_(?P<idx>\d+)")
def _build_affine_expr(self, z3_expr: z3.ExprRef):
    """Translate an affine Z3 arithmetic term into a Gurobi expression.

    Supported terms:

    * rational numerals -> ``fractions.Fraction`` (via ``as_fraction()``);
    * constants named ``<var>_<idx>`` -> the Gurobi variable ``<var>[<idx>]``
      of ``self._gp_model``;
    * sums -> the sum of the translated operands;
    * products -> the product of the translated operands, provided at most one
      operand is non-constant (otherwise the term is not affine).

    Returns:
        A ``Fraction`` for constant terms, otherwise a Gurobi variable or a
        linear expression built from Gurobi variables.

    Raises:
        AssertionError: if a constant's name does not match ``Z3_VAR_RE`` or
            no Gurobi variable with the derived name exists.
        RuntimeError: if the term is not affine or uses an unsupported
            operator.
    """
    import numbers  # local import: used only to recognise constant factors

    if z3.is_rational_value(z3_expr):
        return z3_expr.as_fraction()
    elif z3.is_var(z3_expr) or z3.is_const(z3_expr):
        result = self.Z3_VAR_RE.search(str(z3_expr))
        assert result is not None, str(z3_expr)
        var_name, idx = result.group("var"), result.group("idx")
        gp_var = self._gp_model.getVarByName(f"{var_name}[{idx}]")
        assert gp_var is not None, f"No Gurobi variable named {var_name}[{idx}]"
        return gp_var
    elif z3.is_add(z3_expr):
        return sum(self._build_affine_expr(arg) for arg in z3_expr.children())
    elif z3.is_mul(z3_expr):
        # Fold all constant factors into one coefficient; the product stays
        # affine only if at most one factor involves Gurobi variables.
        coeff = 1
        non_const = []
        for arg in z3_expr.children():
            factor = self._build_affine_expr(arg)
            if isinstance(factor, numbers.Number):
                coeff *= factor
            else:
                non_const.append(factor)
        if not non_const:
            return coeff
        if len(non_const) == 1:
            return coeff * non_const[0]
        # Two or more non-constant factors: fall through to the error below.
    raise RuntimeError(f"Only support affine expressions. {z3_expr.children()}")
def _set_candidate(self, conjunct: z3.BoolRef) -> None:
    """Make ``conjunct`` the candidate constraint set of the Gurobi model.

    Constraints added by the previous call (tracked in
    ``self._prev_candidate_constr``) are removed from ``self._gp_model``
    first, so the model carries the constraints of one candidate at a time.

    ``conjunct`` is simplified with ``flat=True, arith_lhs=True``. The result
    must be ``True`` (no constraint is added), a conjunction of atoms, or a
    single atom. An atom is ``lhs == rhs``, ``lhs <= rhs``, ``lhs >= rhs``,
    or the negation of one of the two inequalities. Strict inequalities
    arising from a negation are approximated with a ``self.PRECISION`` margin:
    ``Not(lhs >= rhs)`` becomes ``lhs <= rhs - PRECISION`` and
    ``Not(lhs <= rhs)`` becomes ``lhs >= rhs + PRECISION``.

    Raises:
        RuntimeError: if the simplified formula is not a conjunction of
            supported atoms, or a term is not affine.
        NotImplementedError: for a negated equality, which is a disjunction
            and cannot be encoded as a single linear constraint.
    """
    # Variable Aliases
    m = self._gp_model

    # Remove constraints from the previous candidate first
    m.remove(self._prev_candidate_constr)
    self._prev_candidate_constr.clear()
    m.update()

    conjunct = z3.simplify(conjunct, flat=True, arith_lhs=True)
    if z3.is_true(conjunct):
        pred_list = []  # trivially true: leave the model unconstrained
    elif z3.is_and(conjunct):
        pred_list = list(conjunct.children())
    elif (z3.is_eq(conjunct) or z3.is_le(conjunct) or z3.is_ge(conjunct)
          or z3.is_not(conjunct)):
        pred_list = [conjunct]
    else:
        raise RuntimeError(f"{conjunct} should be a conjunction.")

    for orig_pred in pred_list:
        negated = z3.is_not(orig_pred)
        pred = orig_pred.arg(0) if negated else orig_pred
        assert z3.is_eq(pred) or z3.is_le(pred) or z3.is_ge(pred), str(pred)
        if negated and z3.is_eq(pred):
            # !(lhs == rhs) <=> (lhs < rhs) or (lhs > rhs): not a single constraint
            raise NotImplementedError(
                f"Negated equality {orig_pred} cannot be encoded as a single "
                "linear constraint")
        lhs = self._build_affine_expr(pred.arg(0))
        rhs = self._build_affine_expr(pred.arg(1))
        if z3.is_eq(pred):
            cons = (lhs == rhs)
        elif z3.is_ge(pred):
            # !(lhs >= rhs) <=> (lhs < rhs) => lhs <= rhs - PRECISION
            cons = (lhs <= rhs - self.PRECISION) if negated else (lhs >= rhs)
        elif z3.is_le(pred):
            # !(lhs <= rhs) <=> (lhs > rhs) => lhs >= rhs + PRECISION
            cons = (lhs >= rhs + self.PRECISION) if negated else (lhs <= rhs)
        else:
            # Reachable only when asserts are disabled (python -O).
            raise RuntimeError(f"Unsupported atomic predicate expression {pred}")
        gp_cons = m.addConstr(cons)
        self._prev_candidate_constr.append(gp_cons)
def _candidate_to_conjuncts(self, candidate: z3.BoolRef):
    """Enumerate the path conditions of ``candidate`` as Z3 conjunctions.

    ``candidate`` is walked as a tree of If-then-else (ITE) terms whose leaves
    are ``True``, ``False`` or an atomic comparison. Root-to-leaf paths are
    visited depth first, then-branch before else-branch. Each path yields:

    * nothing if the leaf is ``False`` (the path can never hold);
    * ``And(*conditions)`` if the leaf is ``True``, or ``BoolVal(True)`` when
      the path has no conditions;
    * ``And(*conditions, leaf)`` if the leaf is a comparison
      (``<``, ``<=``, ``>``, ``>=``, ``==``).

    The disjunction of the yielded conjuncts is therefore a DNF of
    ``candidate``. On an else-branch the ITE condition is negated by flipping
    its comparison operator, so every conjunct contains only comparisons.

    Raises:
        AssertionError: if an ITE condition does not have exactly two operands.
        RuntimeError: if an ITE condition is not one of ``<``, ``<=``, ``>``,
            ``>=``, or a node is none of ITE / Boolean constant / comparison.
    """
    # Each stack entry is a path: conditions taken so far, followed by the
    # subtree still to be expanded as the last element.
    stack = [[candidate]]
    while stack:
        path = stack.pop()
        node = path.pop()
        if z3.is_false(node):
            # The leaf node in this path is false. Skip.
            continue
        if z3.is_true(node):
            yield z3.And(*path) if path else z3.BoolVal(True)
        elif (z3.is_gt(node) or z3.is_ge(node) or z3.is_lt(node)
              or z3.is_le(node) or z3.is_eq(node)):
            yield z3.And(*path, node)
        elif z3.is_app_of(node, z3.Z3_OP_ITE):
            cond, then_node, else_node = node.children()
            assert len(cond.children()) == 2
            lhs, rhs = cond.children()
            # Negate by flipping the operator so the path stays a
            # conjunction of plain comparisons (no Not(...)).
            if z3.is_le(cond):
                not_cond = lhs > rhs
            elif z3.is_ge(cond):
                not_cond = lhs < rhs
            elif z3.is_lt(cond):
                not_cond = lhs >= rhs
            elif z3.is_gt(cond):
                not_cond = lhs <= rhs
            else:
                raise RuntimeError(f"Unexpected condition {cond} for ITE")
            # Push the else-path first so the then-path is popped first.
            stack.append(path + [not_cond, else_node])
            stack.append(path + [cond, then_node])
        else:
            raise RuntimeError(f"Candidate formula {node} should have been converted to DNF.")
def check(self, candidate: z3.BoolRef) -> z3.CheckSatResult:
conjunct_iter = self._candidate_to_conjuncts(candidate)
#TODO ANGELLO: add check if too many conjunctions
print(f"number of conjunct: {len(list(self._candidate_to_conjuncts(candidate)))}")
print("Checking candidate", flush=True)
for conjunct in conjunct_iter:
# print(".", end='', flush=True)
self._set_candidate(conjunct)
self._gp_model.optimize()
if self._gp_model.status == gp.GRB.INF_OR_UNBD:
self._gp_model.setParam("DualReductions", 0)
self._gp_model.optimize()
if self._gp_model.status in [gp.GRB.OPTIMAL, gp.GRB.SUBOPTIMAL]:
cex_list = []
for i in range(self._gp_model.SolCount):
self._gp_model.Params.SolutionNumber = i
cex = tuple(self._old_state.Xn) + tuple(self._percept.Xn)
cex_list.append(cex)
filtered_cex_list = [cex for cex in cex_list
if not self.is_spurious_example(self.state_dim, self.perc_dim, conjunct, cex)]
if filtered_cex_list:
self._cexs.extend(filtered_cex_list)
else:
# raise RuntimeError(f"Only found spurious cexs {cex_list} for the conjunct {conjunct}.")
elif self._gp_model.status == gp.GRB.INFEASIBLE:
continue
elif self._gp_model.status == gp.GRB.INTERRUPTED:
raise KeyboardInterrupt