# teacher_base.py (10.39 KiB)
import abc
import itertools
#from typing import Dict, Hashable, List, Literal, Optional, Sequence, Tuple
from typing import Dict, Hashable, List, Optional, Sequence, Tuple
import gurobipy as gp
import numpy as np
import sympy
import z3
import dreal
def z3_float64_const_to_real(v: float) -> z3.RatNumRef:
    """Convert a Python float into a simplified Z3 real-valued term.

    The value is wrapped as a ``z3.Float64`` floating-point literal,
    converted to the real sort with ``z3.fpToReal`` and finally passed
    through ``z3.simplify``.

    Args:
        v: The floating-point constant to convert.

    Returns:
        The simplified Z3 term (annotated as ``z3.RatNumRef``).
    """
    fp_literal = z3.FPVal(v, z3.Float64())
    as_real = z3.fpToReal(fp_literal)
    return z3.simplify(as_real)
class TeacherBase(abc.ABC):
    """Abstract interface implemented by every teacher back-end in this file.

    Concrete subclasses have to provide ``is_positive_example``, ``check``,
    ``model``, ``reason_unknown`` and ``dump_system_encoding``.
    """

    @staticmethod
    def is_spurious_example(state_dim: int, perc_dim: int,
                            candidate: z3.BoolRef,
                            cex: Tuple[float, ...]) -> bool:
        """Tell whether ``cex`` is spurious with respect to ``candidate``.

        The first ``state_dim`` entries of ``cex`` are bound to the Z3 reals
        ``x_0, x_1, ...`` and the following ``perc_dim`` entries to
        ``z_0, z_1, ...``.  These bindings are substituted into ``candidate``
        and the result is simplified.

        Args:
            state_dim: Number of state variables ``x_i``.
            perc_dim: Number of perception variables ``z_i``.
            candidate: Z3 Boolean formula over the ``x_i`` and ``z_i``.
            cex: Concrete values, state entries first, then perception entries.

        Returns:
            ``True`` if the substituted formula simplifies to false,
            ``False`` if it simplifies to true.

        Raises:
            RuntimeError: If the substituted formula simplifies to neither
                true nor false.
        """
        var_names = [f"x_{i}" for i in range(state_dim)]
        var_names += [f"z_{j}" for j in range(perc_dim)]
        # Position k in var_names lines up with cex[k]: state values come
        # first, perception values follow at offset state_dim.
        bindings = [
            (z3.Real(var_name), z3_float64_const_to_real(cex[pos]))
            for pos, var_name in enumerate(var_names)
        ]
        outcome = z3.simplify(z3.substitute(candidate, *bindings))
        assert z3.is_bool(outcome)
        if z3.is_true(outcome):
            return False
        if z3.is_false(outcome):
            return True
        raise RuntimeError(f"Cannot validate negative example {cex} by substitution")

    def __init__(self) -> None:
        """Initialise the base class; there is no state to set up here."""

    @abc.abstractmethod
    def is_positive_example(self, ex) -> Optional[bool]:
        """Classify the example ``ex``; the annotation allows ``None``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def check(self, candidate) -> z3.CheckSatResult:
        """Check ``candidate`` and report the outcome as a Z3 check result.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def model(self):
        """Return the model produced by the teacher.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def reason_unknown(self) -> str:
        """Return a textual reason for an inconclusive result.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def dump_system_encoding(self, basename: str = "") -> None:
        """Dump the teacher's encoding of the system.

        Must be overridden by subclasses.

        Args:
            basename: Optional base name for the dump; how it is used is up
                to the subclass.
        """
        raise NotImplementedError
class DRealTeacherBase(TeacherBase):
    """Teacher base class that builds and solves its queries with dReal.

    Subclasses must implement ``_add_system`` and ``_add_unsafe`` (both are
    called from ``__init__``) and are expected to override
    ``_gen_cand_pred``, which raises ``NotImplementedError`` here.
    """

    @staticmethod
    def affine_trans_exprs(state, coeff: np.ndarray, intercept=None):
        """Build the row-wise expressions of ``coeff @ state + intercept``.

        Args:
            state: Sequence of variables/expressions, one per column of
                ``coeff``.
            coeff: 2-D coefficient matrix.
            intercept: Optional sequence with one offset per row of
                ``coeff``; defaults to all zeros.

        Returns:
            A list with one expression per row of ``coeff``.
        """
        if intercept is None:
            intercept = np.zeros(coeff.shape[0])
        assert (len(intercept), len(state)) == coeff.shape
        return [
            b_i + sum(col_j * var_j for col_j, var_j in zip(row_i, state))
            for row_i, b_i in zip(coeff, intercept)
        ]

    def __init__(self, name: str,
                 state_dim: int, perc_dim: int, ctrl_dim: int, norm_ord,
                 delta: float = 0.001) -> None:
        """Create the dReal variables and let the subclass add constraints.

        Args:
            name: Not used by this base class.
            state_dim: Number of state variables.
            perc_dim: Number of perception variables.
            ctrl_dim: Number of control variables.
            norm_ord: Norm order, stored as ``self._norm_ord``.
            delta: Precision handed to ``dreal.CheckSatisfiability``.
        """
        # Kept consistent with GurobiTeacherBase, which also chains up.
        super().__init__()
        self._norm_ord = norm_ord
        self._delta = delta

        # Old state variables
        self._old_state = [dreal.Variable(f"x[{i}]", dreal.Variable.Real) for i in range(state_dim)]
        # New state variables
        self._new_state = [dreal.Variable(f"x'[{i}]", dreal.Variable.Real) for i in range(state_dim)]
        # Perception variables
        self._percept = [dreal.Variable(f"z[{i}]", dreal.Variable.Real) for i in range(perc_dim)]
        # Control variables
        self._control = [dreal.Variable(f"u[{i}]", dreal.Variable.Real) for i in range(ctrl_dim)]

        # Every old-state, perception and control variable starts unbounded;
        # new-state variables are not tracked in this map.
        self._var_bounds = {
            var: (-np.inf, np.inf)
            for var in itertools.chain(self._old_state, self._percept, self._control)
        }  # type: Dict[dreal.Variable, Tuple[float, float]]

        self._not_inv_cons = []  # type: List
        self._models = []  # type: List

        self._add_system()
        self._add_unsafe()

    @abc.abstractmethod
    def _add_system(self) -> None:
        """Add the system constraints; called once from ``__init__``."""
        raise NotImplementedError

    @abc.abstractmethod
    def _add_unsafe(self) -> None:
        """Add the unsafe-set constraints; called once from ``__init__``."""
        raise NotImplementedError

    @property
    def state_dim(self) -> int:
        """Number of state variables."""
        return len(self._old_state)

    @property
    def perc_dim(self) -> int:
        """Number of perception variables."""
        return len(self._percept)

    @property
    def ctrl_dim(self) -> int:
        """Number of control variables."""
        return len(self._control)

    def _gen_cand_pred(self, candidate):
        """Translate ``candidate`` into a dReal formula (not implemented here)."""
        raise NotImplementedError(f"TODO convert {candidate} to dReal formula")

    def _bound_preds(self) -> List:
        """Return one ``lb <= var <= ub`` formula per entry of ``_var_bounds``."""
        return [
            dreal.And(lb_i <= var_i, var_i <= ub_i)
            for var_i, (lb_i, ub_i) in self._var_bounds.items()
        ]

    def _set_var_bound(self, smt_vars: List[dreal.Variable], lb: Sequence[float], ub: Sequence[float]) -> None:
        """Record the bounds ``lb[i] <= smt_vars[i] <= ub[i]``.

        ``lb`` and ``ub`` must have the same length as ``smt_vars``.
        """
        assert len(lb) == len(smt_vars)
        assert len(ub) == len(smt_vars)
        for var_i, lb_i, ub_i in zip(smt_vars, lb, ub):
            self._var_bounds[var_i] = (lb_i, ub_i)

    def set_old_state_bound(self, lb: Sequence[float], ub: Sequence[float]) -> None:
        """Set element-wise lower/upper bounds on the old state variables."""
        self._set_var_bound(self._old_state, lb, ub)

    def check(self, candidate) -> z3.CheckSatResult:
        """Solve bounds AND candidate predicate AND ``_not_inv_cons``.

        When dReal returns a truthy result, a single tuple holding the
        result's entries for the old state variables followed by those for
        the perception variables is stored and exposed through ``model()``.

        Args:
            candidate: Candidate passed to ``_gen_cand_pred``.

        Returns:
            ``z3.sat`` if dReal returned a truthy result, else ``z3.unsat``.
        """
        # Forget counterexamples of earlier queries *before* doing any work,
        # so that an exception raised below (e.g. by _gen_cand_pred or the
        # solver) cannot leave stale results visible through model().
        self._models.clear()

        cand_pred = self._gen_cand_pred(candidate)
        query = dreal.And(*self._bound_preds(), cand_pred, *self._not_inv_cons)
        res = dreal.CheckSatisfiability(query, self._delta)
        if not res:
            return z3.unsat

        self._models.append(
            tuple(res[x_i] for x_i in self._old_state) + tuple(res[z_i] for z_i in self._percept)
        )
        return z3.sat

    def dump_system_encoding(self, basename: str = "") -> None:
        """Print the bound formulas and ``_not_inv_cons``; ``basename`` is unused."""
        print(self._bound_preds())
        print(self._not_inv_cons)

    def model(self) -> Sequence[Tuple]:
        """Return the list of counterexamples stored by the last ``check``."""
        return self._models

    def reason_unknown(self) -> str:
        """Not implemented for the dReal back-end."""
        raise NotImplementedError
class GurobiTeacherBase(TeacherBase):
    """Teacher base class that encodes its query as a Gurobi model.

    Subclasses must implement ``_add_system``, ``_add_unsafe`` and
    ``_add_objective`` (all called from ``__init__``) and are expected to
    override ``check``, which raises ``NotImplementedError`` here.
    """

    # Keyword-argument presets for ``addMVar``: all continuous, differing
    # only in their lower/upper bounds.
    FREEVAR = dict(vtype=gp.GRB.CONTINUOUS, lb=-np.inf, ub=np.inf)
    NNEGVAR = dict(vtype=gp.GRB.CONTINUOUS, lb=0.0, ub=np.inf)
    TRIGVAR = dict(vtype=gp.GRB.CONTINUOUS, lb=-1.0, ub=1.0)

    def __init__(self, name: str,
                 state_dim: int, perc_dim: int, ctrl_dim: int,
                 norm_ord = 2) -> None:
        """Create the Gurobi model, its variables and subclass constraints.

        Args:
            name: Name given to the Gurobi model.
            state_dim: Number of state variables.
            perc_dim: Number of perception variables.
            ctrl_dim: Number of control variables.
            norm_ord: Norm order; one of ``1``, ``2`` or ``'inf'``.

        Raises:
            ValueError: If ``norm_ord`` is not one of the supported values.
        """
        super().__init__()

        model = gp.Model(name)
        self._gp_model = model

        if norm_ord not in (1, 2, 'inf'):
            raise ValueError("Norm order %s is not supported." % str(norm_ord))
        if norm_ord == 2:
            # NOTE(review): presumably needed because the 2-norm leads to
            # non-convex quadratic constraints -- confirm.
            model.setParam("NonConvex", 2)
        self._norm_ord = norm_ord

        # Old state variables
        self._old_state = model.addMVar(shape=state_dim, name='x', **self.FREEVAR)
        # New state variables
        self._new_state = model.addMVar(shape=state_dim, name="x'", **self.FREEVAR)
        # Perception variables
        self._percept = model.addMVar(shape=perc_dim, name='z', **self.FREEVAR)
        # Control variables
        self._control = model.addMVar(shape=ctrl_dim, name='u', **self.FREEVAR)

        # FIXME Replace hardcoded constraints for this particular example
        self._add_system()
        self._add_unsafe()

        # Add intermediate variables with constraints for candidates
        self._percept_diff = model.addMVar(name="z-m(x)", shape=perc_dim, **self.FREEVAR)

        self._add_objective()

        self._prev_candidate_constr: List = []
        self._cexs: List = []

    @property
    def state_dim(self) -> int:
        """Number of state variables (first dimension of the ``x`` MVar)."""
        return self._old_state.shape[0]

    @property
    def perc_dim(self) -> int:
        """Number of perception variables (first dimension of the ``z`` MVar)."""
        return self._percept.shape[0]

    @property
    def ctrl_dim(self) -> int:
        """Number of control variables (first dimension of the ``u`` MVar)."""
        return self._control.shape[0]

    @abc.abstractmethod
    def _add_system(self) -> None:
        """Add the system constraints; called once from ``__init__``."""
        raise NotImplementedError

    @abc.abstractmethod
    def _add_unsafe(self) -> None:
        """Add the unsafe-set constraints; called once from ``__init__``."""
        raise NotImplementedError

    @abc.abstractmethod
    def _add_objective(self) -> None:
        """Set up the objective; called at the end of variable creation."""
        raise NotImplementedError

    def set_old_state_bound(self, lb, ub) -> None:
        """Assign lower/upper bounds to the old state and update the model."""
        old_state = self._old_state
        old_state.lb = lb
        old_state.ub = ub
        self._gp_model.update()

    def check(self, candidate) -> z3.CheckSatResult:
        """Not implemented in this base class."""
        raise NotImplementedError

    def dump_system_encoding(self, basename: str = "") -> None:
        """ Dump optimization problem in LP format """
        # An empty basename falls back to the Gurobi model's own name.
        stem = basename or self._gp_model.ModelName
        self._gp_model.write(stem + ".lp")

    def model(self) -> Sequence[Tuple]:
        """Return the stored list of counterexamples."""
        return self._cexs

    def reason_unknown(self) -> str:
        """Describe the current Gurobi model status as a short string."""
        status = self._gp_model.status
        if status == gp.GRB.INF_OR_UNBD:
            return '(model is infeasible or unbounded)'
        if status == gp.GRB.UNBOUNDED:
            return '(model is unbounded)'
        # TODO other status code
        return '(gurobi model status code %d)' % status
class SymPyTeacherBase(TeacherBase):
    """Teacher base class that represents its encoding with SymPy symbols.

    Subclasses must implement ``_add_system`` and ``_add_unsafe`` (both are
    called from ``__init__``) and are expected to override ``check``,
    ``model`` and ``reason_unknown``, which raise ``NotImplementedError``
    here.
    """

    def __init__(self, name: str,
                 state_dim: int, perc_dim: int, ctrl_dim: int, norm_ord) -> None:
        """Create the real-valued symbols and let the subclass add constraints.

        Args:
            name: Not used by this base class.
            state_dim: Number of state variables.
            perc_dim: Number of perception variables.
            ctrl_dim: Number of control variables.
            norm_ord: Norm order, stored as ``self._norm_ord``.
        """
        # Kept consistent with GurobiTeacherBase, which also chains up.
        super().__init__()
        self._norm_ord = norm_ord

        # Old state variables
        self._old_state = sympy.symbols(names=[f"x_{i}" for i in range(state_dim)], real=True)
        # New state variables
        self._new_state = sympy.symbols(names=[f"x'_{i}" for i in range(state_dim)], real=True)
        # Perception variables
        self._percept = sympy.symbols(names=[f"z_{i}" for i in range(perc_dim)], real=True)
        # Control variables
        self._control = sympy.symbols(names=[f"u_{i}" for i in range(ctrl_dim)], real=True)

        # Every symbol starts unbounded.  The keys are SymPy symbols (the
        # annotation previously named dreal.Variable by mistake).
        self._var_bounds = {
            var: (-sympy.oo, sympy.oo)
            for var in itertools.chain(self._old_state, self._new_state, self._percept, self._control)
        }  # type: Dict[sympy.Symbol, Tuple[float, float]]

        self._not_inv_cons = []  # type: List

        self._add_system()
        self._add_unsafe()

    @abc.abstractmethod
    def _add_system(self) -> None:
        """Add the system constraints; called once from ``__init__``."""
        raise NotImplementedError

    @abc.abstractmethod
    def _add_unsafe(self) -> None:
        """Add the unsafe-set constraints; called once from ``__init__``."""
        raise NotImplementedError

    @property
    def state_dim(self) -> int:
        """Number of state variables."""
        return len(self._old_state)

    @property
    def perc_dim(self) -> int:
        """Number of perception variables."""
        return len(self._percept)

    @property
    def ctrl_dim(self) -> int:
        """Number of control variables."""
        return len(self._control)

    def _set_var_bound(self, sym_vars: List[sympy.Symbol], lb: Sequence[float], ub: Sequence[float]) -> None:
        """Record the bounds ``lb[i] <= sym_vars[i] <= ub[i]``.

        ``lb`` and ``ub`` must have the same length as ``sym_vars``.
        """
        assert len(lb) == len(sym_vars)
        assert len(ub) == len(sym_vars)
        for var_i, lb_i, ub_i in zip(sym_vars, lb, ub):
            self._var_bounds[var_i] = (lb_i, ub_i)

    def set_old_state_bound(self, lb: Sequence[float], ub: Sequence[float]) -> None:
        """Set element-wise lower/upper bounds on the old state symbols."""
        self._set_var_bound(self._old_state, lb, ub)

    def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
        """Not implemented in this base class."""
        raise NotImplementedError

    def dump_system_encoding(self, basename: str = "") -> None:
        """Print the simplified conjunction of bounds and ``_not_inv_cons``.

        ``basename`` is unused.
        """
        query = sympy.And(
            *(sympy.And(lb_i <= var_i, var_i <= ub_i)
              for var_i, (lb_i, ub_i) in self._var_bounds.items()),
            *self._not_inv_cons)
        print(sympy.simplify(query))

    def model(self) -> Sequence[Tuple]:
        """Not implemented in this base class."""
        raise NotImplementedError

    def reason_unknown(self) -> str:
        """Not implemented in this base class."""
        raise NotImplementedError