"""Abstract teacher base classes with dReal, Gurobi and SymPy back ends."""
import abc
import itertools
# from typing import Dict, Hashable, List, Literal, Optional, Sequence, Tuple
from typing import Dict, Hashable, List, Optional, Sequence, Tuple

import gurobipy as gp
import numpy as np
import sympy
import z3
import dreal


def z3_float64_const_to_real(v: float) -> z3.RatNumRef:
    """Convert a Python float into a simplified z3 real term.

    The value is first encoded as a z3 ``Float64`` constant, converted with
    ``fpToReal`` and then simplified.

    NOTE(review): the ``RatNumRef`` return annotation presumably only holds
    for finite ``v`` -- confirm how callers treat NaN/infinity.
    """
    return z3.simplify(
        z3.fpToReal(z3.FPVal(v, z3.Float64()))
    )


class TeacherBase(abc.ABC):
    """Interface shared by all teacher implementations in this module."""

    @staticmethod
    def is_spurious_example(state_dim: int, perc_dim: int, candidate: z3.BoolRef,
                            cex: Tuple[float, ...]) -> bool:
        """Check by substitution whether ``candidate`` evaluates to false on ``cex``.

        The first ``state_dim`` entries of ``cex`` are substituted for the z3
        reals ``x_0 .. x_{state_dim-1}`` and the next ``perc_dim`` entries for
        ``z_0 .. z_{perc_dim-1}``.

        Returns:
            True if the substituted candidate simplifies to ``False``,
            False if it simplifies to ``True``.

        Raises:
            RuntimeError: if the substituted candidate simplifies to neither
                literal truth value.
        """
        state_subs_map = [
            (z3.Real(f"x_{i}"), z3_float64_const_to_real(cex[i]))
            for i in range(state_dim)
        ]
        perc_subs_map = [
            (z3.Real(f"z_{i}"), z3_float64_const_to_real(cex[i + state_dim]))
            for i in range(perc_dim)
        ]
        sub_map = state_subs_map + perc_subs_map
        val = z3.simplify(z3.substitute(candidate, *sub_map))
        assert z3.is_bool(val)
        if z3.is_false(val):
            return True
        if z3.is_true(val):
            return False
        raise RuntimeError(f"Cannot validate negative example {cex} by substitution")

    def __init__(self) -> None:
        pass

    @abc.abstractmethod
    def is_positive_example(self, ex) -> Optional[bool]:
        """Classify the example ``ex``; must be provided by subclasses.

        NOTE(review): the ``Optional`` return presumably lets implementations
        answer ``None`` when undecided -- confirm against implementations.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def check(self, candidate) -> z3.CheckSatResult:
        """Check ``candidate`` and report the outcome as a z3 check result."""
        raise NotImplementedError

    @abc.abstractmethod
    def model(self):
        """Return the model(s) produced by the most recent ``check``."""
        raise NotImplementedError

    @abc.abstractmethod
    def reason_unknown(self) -> str:
        """Return a description of why the last ``check`` was inconclusive."""
        raise NotImplementedError

    @abc.abstractmethod
    def dump_system_encoding(self, basename: str = "") -> None:
        """Dump the encoded system; ``basename`` usage is back-end specific."""
        raise NotImplementedError


class DRealTeacherBase(TeacherBase):
    """Teacher base class that queries dReal with delta-satisfiability checks.

    Subclasses must implement ``_add_system`` and ``_add_unsafe`` (both are
    invoked from ``__init__``) and override ``_gen_cand_pred`` before
    ``check`` can be used.
    """

    @staticmethod
    def affine_trans_exprs(state, coeff: np.ndarray, intercept=None):
        """Build one expression ``intercept[i] + coeff[i] . state`` per row of ``coeff``.

        Args:
            state: sequence of variables/values, one per column of ``coeff``.
            coeff: 2-D coefficient array.
            intercept: optional sequence with one entry per row of ``coeff``;
                defaults to all zeros.

        Returns:
            A list with one expression per row of ``coeff``.
        """
        if intercept is None:
            intercept = np.zeros(coeff.shape[0])
        assert (len(intercept), len(state)) == coeff.shape
        return [
            b_i + sum(col_j * var_j for col_j, var_j in zip(row_i, state))
            for row_i, b_i in zip(coeff, intercept)
        ]

    def __init__(self, name: str,
                 state_dim: int, perc_dim: int, ctrl_dim: int,
                 norm_ord, delta: float = 0.001) -> None:
        """Create the dReal variables, then add the system and unsafe constraints.

        Args:
            name: accepted for signature parity with the other teachers; not
                used by this constructor.
            state_dim: number of old/new state variables.
            perc_dim: number of perception variables.
            ctrl_dim: number of control variables.
            norm_ord: norm order, stored as-is for subclasses.
            delta: precision passed to ``dreal.CheckSatisfiability``.
        """
        # Keep the initialisation chain consistent with GurobiTeacherBase.
        super().__init__()
        self._norm_ord = norm_ord
        self._delta = delta

        # Old state variables
        self._old_state = [
            dreal.Variable(f"x[{i}]", dreal.Variable.Real) for i in range(state_dim)
        ]
        # New state variables
        self._new_state = [
            dreal.Variable(f"x'[{i}]", dreal.Variable.Real) for i in range(state_dim)
        ]
        # Perception variables
        self._percept = [
            dreal.Variable(f"z[{i}]", dreal.Variable.Real) for i in range(perc_dim)
        ]
        # Control variables
        self._control = [
            dreal.Variable(f"u[{i}]", dreal.Variable.Real) for i in range(ctrl_dim)
        ]

        # Every old-state, perception and control variable starts unbounded.
        # New-state variables are not given an entry here.
        self._var_bounds: Dict[dreal.Variable, Tuple[float, float]] = {
            var: (-np.inf, np.inf)
            for var in itertools.chain(self._old_state, self._percept, self._control)
        }

        self._not_inv_cons: List = []
        self._models: List = []

        self._add_system()
        self._add_unsafe()

    @abc.abstractmethod
    def _add_system(self) -> None:
        """Add the system constraints; called once from ``__init__``."""
        raise NotImplementedError

    @abc.abstractmethod
    def _add_unsafe(self) -> None:
        """Add the unsafe-region constraints; called once from ``__init__``."""
        raise NotImplementedError

    @property
    def state_dim(self) -> int:
        """Number of (old) state variables."""
        return len(self._old_state)

    @property
    def perc_dim(self) -> int:
        """Number of perception variables."""
        return len(self._percept)

    @property
    def ctrl_dim(self) -> int:
        """Number of control variables."""
        return len(self._control)

    def _gen_cand_pred(self, candidate):
        """Translate ``candidate`` into a dReal formula (not implemented here)."""
        raise NotImplementedError(f"TODO convert {candidate} to dReal formula")

    def _set_var_bound(self, smt_vars: List[dreal.Variable],
                       lb: Sequence[float], ub: Sequence[float]) -> None:
        """Record ``(lb[i], ub[i])`` as the bounds of ``smt_vars[i]``."""
        assert len(lb) == len(smt_vars)
        assert len(ub) == len(smt_vars)
        for var_i, lb_i, ub_i in zip(smt_vars, lb, ub):
            self._var_bounds[var_i] = (lb_i, ub_i)

    def set_old_state_bound(self, lb: Sequence[float], ub: Sequence[float]) -> None:
        """Set lower/upper bounds for the old state variables."""
        self._set_var_bound(self._old_state, lb, ub)

    def _bound_preds(self) -> List:
        """Return one ``lb <= var <= ub`` dReal formula per recorded variable bound."""
        return [
            dreal.And(lb_i <= var_i, var_i <= ub_i)
            for var_i, (lb_i, ub_i) in self._var_bounds.items()
        ]

    def check(self, candidate) -> z3.CheckSatResult:
        """Ask dReal whether bounds, candidate predicate and stored constraints hold together.

        On a satisfiable answer the stored models are replaced by a single
        tuple holding the solver's entries for the old-state variables
        followed by those for the perception variables.

        Returns:
            ``z3.sat`` if dReal returns a result, ``z3.unsat`` otherwise.
        """
        bound_preds = self._bound_preds()
        cand_pred = self._gen_cand_pred(candidate)
        query = dreal.And(*bound_preds, cand_pred, *self._not_inv_cons)
        res = dreal.CheckSatisfiability(query, self._delta)

        self._models.clear()
        # Compare against None explicitly so that a returned result is never
        # mistaken for "no result" because of its truthiness.
        if res is not None:
            self._models.append(
                tuple(res[x_i] for x_i in self._old_state)
                + tuple(res[z_i] for z_i in self._percept)
            )
            return z3.sat
        return z3.unsat

    def dump_system_encoding(self, basename: str = "") -> None:
        """Print the bound predicates and the stored constraints.

        ``basename`` is ignored by this implementation.
        """
        print(self._bound_preds())
        print(self._not_inv_cons)

    def model(self) -> Sequence[Tuple]:
        """Return the models recorded by the most recent ``check``."""
        return self._models

    def reason_unknown(self) -> str:
        raise NotImplementedError


class GurobiTeacherBase(TeacherBase):
    """Teacher base class that encodes the problem as a Gurobi model.

    Subclasses must implement ``_add_system``, ``_add_unsafe`` and
    ``_add_objective`` (all invoked from ``__init__``) as well as ``check``.
    """

    # Keyword-argument presets for ``Model.addMVar``.
    FREEVAR = {"vtype": gp.GRB.CONTINUOUS, "lb": -np.inf, "ub": np.inf}
    NNEGVAR = {"vtype": gp.GRB.CONTINUOUS, "lb": 0.0, "ub": np.inf}
    TRIGVAR = {"vtype": gp.GRB.CONTINUOUS, "lb": -1.0, "ub": 1.0}

    def __init__(self, name: str,
                 state_dim: int, perc_dim: int, ctrl_dim: int,
                 norm_ord=2) -> None:
        """Create the Gurobi model and its variables, then add constraints and objective.

        Args:
            name: name of the Gurobi model.
            state_dim: number of old/new state variables.
            perc_dim: number of perception variables.
            ctrl_dim: number of control variables.
            norm_ord: norm order; one of ``1``, ``2`` or ``'inf'``.

        Raises:
            ValueError: if ``norm_ord`` is not supported.
        """
        super().__init__()

        self._gp_model = gp.Model(name)
        m = self._gp_model

        if norm_ord not in [1, 2, 'inf']:
            raise ValueError(f"Norm order {norm_ord} is not supported.")
        if norm_ord == 2:
            # Only the 2-norm case switches on Gurobi's NonConvex parameter.
            self._gp_model.setParam("NonConvex", 2)
        self._norm_ord = norm_ord

        # Old state variables
        self._old_state = m.addMVar(shape=state_dim, name='x', **self.FREEVAR)
        # New state variables
        self._new_state = m.addMVar(shape=state_dim, name="x'", **self.FREEVAR)
        # Perception variables
        self._percept = m.addMVar(shape=perc_dim, name='z', **self.FREEVAR)
        # Control variables
        self._control = m.addMVar(shape=ctrl_dim, name='u', **self.FREEVAR)

        # FIXME Replace hardcoded constraints for this particular example
        self._add_system()
        self._add_unsafe()

        # Add intermediate variables with constraints for candidates
        self._percept_diff = m.addMVar(name="z-m(x)", shape=perc_dim, **self.FREEVAR)

        self._add_objective()
        self._prev_candidate_constr: List = []
        self._cexs: List = []

    @property
    def state_dim(self) -> int:
        """Number of (old) state variables."""
        return self._old_state.shape[0]

    @property
    def perc_dim(self) -> int:
        """Number of perception variables."""
        return self._percept.shape[0]

    @property
    def ctrl_dim(self) -> int:
        """Number of control variables."""
        return self._control.shape[0]

    @abc.abstractmethod
    def _add_system(self) -> None:
        """Add the system constraints; called once from ``__init__``."""
        raise NotImplementedError

    @abc.abstractmethod
    def _add_unsafe(self) -> None:
        """Add the unsafe-region constraints; called once from ``__init__``."""
        raise NotImplementedError

    @abc.abstractmethod
    def _add_objective(self) -> None:
        """Add the optimisation objective; called once from ``__init__``."""
        raise NotImplementedError

    def set_old_state_bound(self, lb, ub) -> None:
        """Set lower/upper bounds of the old state variables and update the model."""
        self._old_state.lb = lb
        self._old_state.ub = ub
        self._gp_model.update()

    def check(self, candidate) -> z3.CheckSatResult:
        raise NotImplementedError

    def dump_system_encoding(self, basename: str = "") -> None:
        """ Dump optimization problem in LP format

        Writes ``<basename>.lp``; an empty ``basename`` falls back to the
        Gurobi model name.
        """
        if not basename:
            basename = self._gp_model.ModelName
        self._gp_model.write(basename + ".lp")

    def model(self) -> Sequence[Tuple]:
        """Return the stored counterexamples."""
        return self._cexs

    def reason_unknown(self) -> str:
        """Describe the current Gurobi model status as a parenthesised string."""
        status = self._gp_model.status
        if status == gp.GRB.INF_OR_UNBD:
            return '(model is infeasible or unbounded)'
        if status == gp.GRB.UNBOUNDED:
            return '(model is unbounded)'
        # TODO other status code
        return f'(gurobi model status code {status:d})'


class SymPyTeacherBase(TeacherBase):
    """Teacher base class that represents variables and constraints with SymPy.

    Subclasses must implement ``_add_system`` and ``_add_unsafe`` (both are
    invoked from ``__init__``), plus ``check``, ``model`` and
    ``reason_unknown``.
    """

    def __init__(self, name: str,
                 state_dim: int, perc_dim: int, ctrl_dim: int,
                 norm_ord) -> None:
        """Create the SymPy symbols, then add the system and unsafe constraints.

        Args:
            name: accepted for signature parity with the other teachers; not
                used by this constructor.
            state_dim: number of old/new state symbols.
            perc_dim: number of perception symbols.
            ctrl_dim: number of control symbols.
            norm_ord: norm order, stored as-is for subclasses.
        """
        # Keep the initialisation chain consistent with GurobiTeacherBase.
        super().__init__()
        self._norm_ord = norm_ord

        # Old state variables
        self._old_state = sympy.symbols(
            names=[f"x_{i}" for i in range(state_dim)], real=True)
        # New state variables
        self._new_state = sympy.symbols(
            names=[f"x'_{i}" for i in range(state_dim)], real=True)
        # Perception variables
        self._percept = sympy.symbols(
            names=[f"z_{i}" for i in range(perc_dim)], real=True)
        # Control variables
        self._control = sympy.symbols(
            names=[f"u_{i}" for i in range(ctrl_dim)], real=True)

        # Every symbol (including the new-state ones) starts unbounded.
        self._var_bounds: Dict[sympy.Symbol, Tuple] = {
            var: (-sympy.oo, sympy.oo)
            for var in itertools.chain(
                self._old_state, self._new_state, self._percept, self._control)
        }

        self._not_inv_cons: List = []

        self._add_system()
        self._add_unsafe()

    @abc.abstractmethod
    def _add_system(self) -> None:
        """Add the system constraints; called once from ``__init__``."""
        raise NotImplementedError

    @abc.abstractmethod
    def _add_unsafe(self) -> None:
        """Add the unsafe-region constraints; called once from ``__init__``."""
        raise NotImplementedError

    @property
    def state_dim(self) -> int:
        """Number of (old) state symbols."""
        return len(self._old_state)

    @property
    def perc_dim(self) -> int:
        """Number of perception symbols."""
        return len(self._percept)

    @property
    def ctrl_dim(self) -> int:
        """Number of control symbols."""
        return len(self._control)

    def _set_var_bound(self, sym_vars: List[sympy.Symbol],
                       lb: Sequence[float], ub: Sequence[float]) -> None:
        """Record ``(lb[i], ub[i])`` as the bounds of ``sym_vars[i]``."""
        assert len(lb) == len(sym_vars)
        assert len(ub) == len(sym_vars)
        for var_i, lb_i, ub_i in zip(sym_vars, lb, ub):
            self._var_bounds[var_i] = (lb_i, ub_i)

    def set_old_state_bound(self, lb: Sequence[float], ub: Sequence[float]) -> None:
        """Set lower/upper bounds for the old state symbols."""
        self._set_var_bound(self._old_state, lb, ub)

    def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
        raise NotImplementedError

    def dump_system_encoding(self, basename: str = "") -> None:
        """Print the simplified conjunction of all bounds and stored constraints.

        ``basename`` is ignored by this implementation.
        """
        query = sympy.And(
            *(sympy.And(lb_i <= var_i, var_i <= ub_i)
              for var_i, (lb_i, ub_i) in self._var_bounds.items()),
            *self._not_inv_cons)
        print(sympy.simplify(query))

    def model(self) -> Sequence[Tuple]:
        raise NotImplementedError

    def reason_unknown(self) -> str:
        raise NotImplementedError