Skip to content
Snippets Groups Projects
dtree_teacher_base.py 6.35 KiB
Newer Older
  • Learn to ignore specific revisions
  • 
    import re

    import gurobipy as gp
    import z3
    
    chsieh16's avatar
    chsieh16 committed
    from teacher_base import GurobiTeacherBase
    
    chsieh16's avatar
    chsieh16 committed
    class DTreeGurobiTeacherBase(GurobiTeacherBase):
        """Gurobi-based teacher for candidates shaped like decision trees.

        A candidate is a z3 formula built from nested ``If`` (ITE) terms whose
        conditions are linear comparisons. :meth:`check` splits it into one
        conjunction per accepting path of the tree and, for each conjunction,
        installs it in the Gurobi model (``self._gp_model``, provided by the base
        class) and collects the model's solutions as counterexamples.
        """

        # Margin used to encode a strict inequality as a non-strict one for
        # Gurobi: ``lhs < rhs`` becomes ``lhs <= rhs - PRECISION``.
        PRECISION = 10**-3

        # Splits a z3 constant name such as ``x_3`` into var="x", idx="3"; the
        # Gurobi variable is then looked up by the name ``x[3]``.
        Z3_VAR_RE = re.compile(r"(?P<var>\w+)_(?P<idx>\d+)")

        def _build_affine_expr(self, z3_expr: z3.ExprRef):
            """Translate an arithmetic z3 term into the matching Gurobi expression.

            Supported terms:

            * real numerals (returned as ``fractions.Fraction`` via
              ``as_fraction``) and integer numerals (returned as ``int``);
            * constants named ``<var>_<idx>``, mapped to the model variable
              named ``<var>[<idx>]``;
            * sums of supported terms;
            * products of exactly two supported terms.

            Raises:
                AssertionError: (via ``assert``) if a constant's name does not
                    match ``Z3_VAR_RE`` or the model has no variable of the
                    derived name.
                NotImplementedError: for products with three or more operands.
                RuntimeError: for any other kind of term.
            """
            if z3.is_rational_value(z3_expr):
                return z3_expr.as_fraction()
            if z3.is_int_value(z3_expr):
                # Checked before is_const: numerals are also z3 constants and
                # would otherwise fail the variable-name regex below.
                return z3_expr.as_long()
            if z3.is_var(z3_expr) or z3.is_const(z3_expr):
                result = self.Z3_VAR_RE.search(str(z3_expr))
                assert result is not None, str(z3_expr)

                var_name, idx = result.group("var"), result.group("idx")
                gp_name = f"{var_name}[{idx}]"
                gp_var = self._gp_model.getVarByName(gp_name)
                assert gp_var is not None, f"No Gurobi variable named {gp_name}"
                return gp_var
            if z3.is_add(z3_expr):
                return sum(self._build_affine_expr(arg) for arg in z3_expr.children())
            if z3.is_mul(z3_expr):
                if len(z3_expr.children()) > 2:
                    raise NotImplementedError("TODO: multiplication of three or more operands")
                lhs = self._build_affine_expr(z3_expr.arg(0))
                rhs = self._build_affine_expr(z3_expr.arg(1))
                return lhs * rhs

            raise RuntimeError(f"Only support affine expressions. {z3_expr}")

        def _set_candidate(self, conjunct: z3.BoolRef) -> None:
            """Replace the model's candidate constraints with ones encoding ``conjunct``.

            Constraints added for the previous candidate (tracked in
            ``self._prev_candidate_constr``) are removed first. ``conjunct`` is
            simplified by z3 and must then be ``True`` (no constraints), a
            conjunction, or a single, possibly negated, ``==``/``<=``/``>=``
            atom. Negated inequalities are strict and are approximated with
            :attr:`PRECISION`.

            Raises:
                RuntimeError: if the simplified formula is not a conjunction of
                    supported atoms, or contains a negated equality (a
                    disjunction with no single linear-constraint encoding).
            """
            m = self._gp_model

            # Remove constraints from the previous candidate first.
            m.remove(self._prev_candidate_constr)
            self._prev_candidate_constr.clear()
            m.update()

            conjunct = z3.simplify(conjunct, flat=True, arith_lhs=True)
            if z3.is_true(conjunct):
                pred_list = []  # Trivially true: nothing to constrain.
            elif z3.is_and(conjunct):
                pred_list = list(conjunct.children())
            elif z3.is_eq(conjunct) or z3.is_le(conjunct) or z3.is_ge(conjunct) or z3.is_not(conjunct):
                pred_list = [conjunct]
            else:
                raise RuntimeError(f"{conjunct} should be a conjunction.")

            for orig_pred in pred_list:
                negated = z3.is_not(orig_pred)
                pred = orig_pred.arg(0) if negated else orig_pred

                assert z3.is_eq(pred) or z3.is_le(pred) or z3.is_ge(pred), str(pred)
                lhs = self._build_affine_expr(pred.arg(0))
                rhs = self._build_affine_expr(pred.arg(1))

                if z3.is_eq(pred):
                    if negated:
                        # !(lhs == rhs) means lhs < rhs OR lhs > rhs; encoding it
                        # as lhs == rhs would install the opposite constraint.
                        raise RuntimeError(f"Unsupported atomic predicate expression {orig_pred}")
                    cons = (lhs == rhs)
                elif z3.is_ge(pred):
                    if negated:  # !(lhs >= rhs) <=> (lhs < rhs) => lhs <= rhs - PRECISION
                        cons = (lhs <= rhs - self.PRECISION)
                    else:
                        cons = (lhs >= rhs)
                elif z3.is_le(pred):
                    if negated:  # !(lhs <= rhs) <=> (lhs > rhs) => lhs >= rhs + PRECISION
                        cons = (lhs >= rhs + self.PRECISION)
                    else:
                        cons = (lhs <= rhs)
                else:
                    raise RuntimeError(f"Unsupported atomic predicate expression {pred}")

                gp_cons = m.addConstr(cons)
                self._prev_candidate_constr.append(gp_cons)

        def _candidate_to_conjuncts(self, candidate: z3.BoolRef):
            """Yield one conjunction per accepting path of an ITE-tree candidate.

            ``candidate`` must be a tree of ``If(cond, then, else)`` terms whose
            conditions are binary comparisons (``<``, ``<=``, ``>``, ``>=``) and
            whose leaves are ``True``, ``False``, or a single comparison. The
            tree is walked depth-first, ``then`` branch before ``else`` branch,
            and each path accumulates the conditions it takes (the complemented
            comparison on ``else`` branches):

            * a ``False`` leaf yields nothing;
            * a ``True`` leaf yields ``And(*conditions)``, or ``BoolVal(True)``
              for the empty path;
            * a comparison leaf yields ``And(*conditions, leaf)``.

            The disjunction of everything yielded is equivalent to ``candidate``.

            Raises:
                RuntimeError: on any other node kind, or on an ITE whose
                    condition is not a binary comparison.
            """
            # Each stack entry is a path: conditions taken so far, followed by
            # the sub-tree still to be expanded as its last element.
            stack = [[candidate]]
            while stack:
                path = stack.pop()
                node = path.pop()
                if z3.is_false(node):
                    continue  # Rejecting leaf: contributes nothing.
                elif z3.is_true(node):
                    yield z3.And(*path) if path else z3.BoolVal(True)
                elif z3.is_gt(node) or z3.is_ge(node) or z3.is_lt(node) or z3.is_le(node):
                    yield z3.And(*path, node)
                elif z3.is_app_of(node, z3.Z3_OP_ITE):
                    cond, then_branch, else_branch = node.children()
                    # Complement the condition explicitly so that the else path
                    # stays a plain comparison (no z3 Not).
                    if z3.is_le(cond):
                        lhs, rhs = cond.children()
                        not_cond = lhs > rhs
                    elif z3.is_ge(cond):
                        lhs, rhs = cond.children()
                        not_cond = lhs < rhs
                    elif z3.is_lt(cond):
                        lhs, rhs = cond.children()
                        not_cond = lhs >= rhs
                    elif z3.is_gt(cond):
                        lhs, rhs = cond.children()
                        not_cond = lhs <= rhs
                    else:
                        raise RuntimeError(f"Unexpected condition {cond} for ITE")

                    stack.append(path + [not_cond, else_branch])
                    # Pushed last so the then-branch is explored first.
                    stack.append(path + [cond, then_branch])
                else:
                    raise RuntimeError(f"Candidate formula {node} should have been converted to DNF.")

        def check(self, candidate: z3.BoolRef) -> z3.CheckSatResult:
            """Search the Gurobi model for counterexamples to ``candidate``.

            Each conjunction produced by :meth:`_candidate_to_conjuncts` is
            installed with :meth:`_set_candidate` and the model is optimized.
            For a (sub)optimal solve, every solution in the pool yields a tuple
            ``(*old_state, *percept)``. Tuples for which ``is_spurious_example``
            returns true are dropped, and the rest are appended to
            ``self._cexs`` (cleared on entry).

            Returns:
                ``z3.sat`` if any counterexample was collected, ``z3.unsat`` if
                none was, or ``z3.unknown`` immediately when a solve ends in a
                status other than optimal, suboptimal, infeasible or
                interrupted. In that case ``self._cexs`` may hold the
                counterexamples gathered so far.

            Raises:
                KeyboardInterrupt: if Gurobi reports the solve was interrupted.
            """
            self._cexs.clear()

            # Materialize once: the count is reported up front, and walking the
            # ITE tree a second time would duplicate that work.
            conjuncts = list(self._candidate_to_conjuncts(candidate))

            # TODO(ANGELLO): add check if too many conjunctions
            print(f"number of conjunct:  {len(conjuncts)}")

            print("Checking candidate", flush=True)
            for conjunct in conjuncts:
                self._set_candidate(conjunct)
                self._gp_model.optimize()
                if self._gp_model.status == gp.GRB.INF_OR_UNBD:
                    # Re-solve with dual reductions off so Gurobi distinguishes
                    # infeasible from unbounded. NOTE: the parameter is not
                    # restored, so later solves on this model keep it at 0.
                    self._gp_model.setParam("DualReductions", 0)
                    self._gp_model.optimize()

                status = self._gp_model.status
                if status in (gp.GRB.OPTIMAL, gp.GRB.SUBOPTIMAL):
                    cex_list = []
                    for sol_idx in range(self._gp_model.SolCount):
                        # ``Xn`` reads the pool solution selected by SolutionNumber.
                        self._gp_model.Params.SolutionNumber = sol_idx
                        cex_list.append(tuple(self._old_state.Xn) + tuple(self._percept.Xn))
                    # Deliberately best-effort: a conjunct whose solutions are all
                    # spurious is skipped rather than treated as an error.
                    self._cexs.extend(
                        cex for cex in cex_list
                        if not self.is_spurious_example(self.state_dim, self.perc_dim, conjunct, cex)
                    )
                elif status == gp.GRB.INFEASIBLE:
                    continue
                elif status == gp.GRB.INTERRUPTED:
                    raise KeyboardInterrupt
                else:
                    return z3.unknown

            return z3.sat if self._cexs else z3.unsat