Skip to content
Snippets Groups Projects
teacher_base.py 10.4 KiB
Newer Older
  • Learn to ignore specific revisions
  • import abc
    
    chsieh16's avatar
    chsieh16 committed
    import itertools
    
    aastorg2's avatar
    aastorg2 committed
    #from typing import Dict, Hashable, List, Literal, Optional, Sequence, Tuple
    from typing import Dict, Hashable, List, Optional, Sequence, Tuple
    
    chsieh16's avatar
    chsieh16 committed
    import gurobipy as gp
    import numpy as np
    import sympy
    import z3

    import dreal
    
    def z3_float64_const_to_real(v: float) -> z3.RatNumRef:
        """Turn a Python float into a simplified Z3 real-valued term.

        The float is wrapped as a Z3 floating-point value of sort
        ``Float64``, converted to the real sort with ``fpToReal`` and the
        resulting term is passed through ``z3.simplify``.

        Args:
            v: The floating-point constant to convert.

        Returns:
            The simplified Z3 term produced from ``v``.
        """
        fp64_val = z3.FPVal(v, z3.Float64())
        return z3.simplify(z3.fpToReal(fp64_val))
    
    
    
    class TeacherBase(abc.ABC):
        """Abstract interface that every teacher implementation must provide.

        Subclasses implement example classification, candidate checking,
        access to the produced models, a diagnostic for inconclusive checks
        and a dump of the encoded system.
        """

        @staticmethod
        def is_spurious_example(state_dim: int, perc_dim: int, candidate: z3.BoolRef, cex: Tuple[float, ...]) -> bool:
            """Evaluate ``candidate`` on ``cex`` by substitution.

            The first ``state_dim`` entries of ``cex`` replace the Z3 reals
            ``x_0, x_1, ...`` and the following ``perc_dim`` entries replace
            ``z_0, z_1, ...``.

            Args:
                state_dim: Number of state components at the front of ``cex``.
                perc_dim: Number of perception components following them.
                candidate: Z3 Boolean formula over the ``x_i`` / ``z_i`` reals.
                cex: Example values, state components first.

            Returns:
                ``True`` when the substituted formula simplifies to false,
                ``False`` when it simplifies to true.

            Raises:
                RuntimeError: If the substituted formula simplifies to
                    neither true nor false.
            """
            substitutions = []
            for idx in range(state_dim):
                substitutions.append(
                    (z3.Real(f"x_{idx}"), z3_float64_const_to_real(cex[idx]))
                )
            for idx in range(perc_dim):
                # Perception values are stored right after the state values.
                substitutions.append(
                    (z3.Real(f"z_{idx}"), z3_float64_const_to_real(cex[state_dim + idx]))
                )

            evaluated = z3.simplify(z3.substitute(candidate, *substitutions))
            assert z3.is_bool(evaluated)
            if z3.is_true(evaluated):
                return False
            if z3.is_false(evaluated):
                return True
            raise RuntimeError(f"Cannot validate negative example {cex} by substitution")

        def __init__(self) -> None:
            """Initialize the base class; there is no state to set up here."""

        @abc.abstractmethod
        def is_positive_example(self, ex) -> Optional[bool]:
            """Classify ``ex``; the result may be ``None`` per the annotation."""
            raise NotImplementedError

        @abc.abstractmethod
        def check(self, candidate) -> z3.CheckSatResult:
            """Check ``candidate`` and report the outcome as a Z3 result."""
            raise NotImplementedError

        @abc.abstractmethod
        def model(self):
            """Return the model(s) produced by the teacher."""
            raise NotImplementedError

        @abc.abstractmethod
        def reason_unknown(self) -> str:
            """Return a textual explanation for an inconclusive check."""
            raise NotImplementedError

        @abc.abstractmethod
        def dump_system_encoding(self, basename: str = "") -> None:
            """Dump the encoded system; ``basename`` optionally names the output."""
            raise NotImplementedError
    
    
    chsieh16's avatar
    chsieh16 committed
    class DRealTeacherBase(TeacherBase):
        """Teacher skeleton that discharges its queries with dReal.

        The constructor creates dReal real variables for the old state, new
        state, perception and control vectors, gives every old-state,
        perception and control variable the bounds ``(-inf, inf)``, and then
        calls the subclass hooks ``_add_system`` and ``_add_unsafe``.
        """

        @staticmethod
        def affine_trans_exprs(state, coeff: np.ndarray, intercept=None):
            """Build the expressions of ``coeff @ state + intercept`` row by row.

            Args:
                state: Sequence of variables/expressions, one per column of
                    ``coeff``.
                coeff: 2-D coefficient array.
                intercept: Optional offset with one entry per row of
                    ``coeff``; defaults to all zeros.

            Returns:
                A list with one expression per row of ``coeff``.
            """
            if intercept is None:
                intercept = np.zeros(coeff.shape[0])
            assert (len(intercept), len(state)) == coeff.shape
            return [
                b_i + sum(col_j*var_j for col_j, var_j in zip(row_i, state))
                for row_i, b_i in zip(coeff, intercept)
            ]

        def __init__(self, name: str,
                     state_dim: int, perc_dim: int, ctrl_dim: int, norm_ord,
                     delta: float = 0.001) -> None:
            """Create the dReal variables and let the subclass add constraints.

            Args:
                name: Accepted for signature parity with the other teachers;
                    not used by this base class.
                state_dim: Number of state variables.
                perc_dim: Number of perception variables.
                ctrl_dim: Number of control variables.
                norm_ord: Norm order, stored as ``self._norm_ord``.
                delta: Precision passed to ``dreal.CheckSatisfiability``.
            """
            super().__init__()
            self._norm_ord = norm_ord
            self._delta = delta

            # Old state variables
            self._old_state = [dreal.Variable(f"x[{i}]", dreal.Variable.Real) for i in range(state_dim)]
            # New state variables
            self._new_state = [dreal.Variable(f"x'[{i}]", dreal.Variable.Real) for i in range(state_dim)]
            # Perception variables
            self._percept = [dreal.Variable(f"z[{i}]", dreal.Variable.Real) for i in range(perc_dim)]
            # Control variables
            self._control = [dreal.Variable(f"u[{i}]", dreal.Variable.Real) for i in range(ctrl_dim)]

            # New-state variables are intentionally left out of the bound map
            # here, exactly as in the original code.
            self._var_bounds = {
                var: (-np.inf, np.inf)
                for var in itertools.chain(self._old_state, self._percept, self._control)
            }  # type: Dict[dreal.Variable, Tuple[float, float]]
            self._not_inv_cons = []  # type: List
            self._models = []  # type: List

            # Subclass hooks run last so they can rely on all attributes above.
            self._add_system()
            self._add_unsafe()

        @abc.abstractmethod
        def _add_system(self) -> None:
            """Hook invoked by ``__init__`` to add the system constraints."""
            raise NotImplementedError

        @abc.abstractmethod
        def _add_unsafe(self) -> None:
            """Hook invoked by ``__init__`` to add the unsafe-set constraints."""
            raise NotImplementedError

        @property
        def state_dim(self) -> int:
            """Number of old-state variables."""
            return len(self._old_state)

        @property
        def perc_dim(self) -> int:
            """Number of perception variables."""
            return len(self._percept)

        @property
        def ctrl_dim(self) -> int:
            """Number of control variables."""
            return len(self._control)

        def _gen_cand_pred(self, candidate):
            """Convert ``candidate`` into a dReal formula (not implemented here)."""
            raise NotImplementedError(f"TODO convert {candidate} to dReal formula")

        def _set_var_bound(self, smt_vars: List[dreal.Variable], lb: Sequence[float], ub: Sequence[float]) -> None:
            """Record ``(lb[i], ub[i])`` as the bounds of ``smt_vars[i]``."""
            assert len(lb) == len(smt_vars)
            assert len(ub) == len(smt_vars)
            for var_i, lb_i, ub_i in zip(smt_vars, lb, ub):
                self._var_bounds[var_i] = (lb_i, ub_i)

        def set_old_state_bound(self, lb: Sequence[float], ub: Sequence[float]) -> None:
            """Set lower/upper bounds for the old-state variables."""
            self._set_var_bound(self._old_state, lb, ub)

        def check(self, candidate) -> z3.CheckSatResult:
            """Ask dReal whether the bounds, candidate and constraints are satisfiable.

            The query conjoins the recorded variable bounds, the formula from
            ``_gen_cand_pred(candidate)`` and everything in
            ``self._not_inv_cons``. The stored models are cleared on every
            call; when dReal returns a result, the entries for the old-state
            variables followed by those for the perception variables are
            stored as a single tuple.

            Returns:
                ``z3.sat`` when dReal returns a result, ``z3.unsat`` otherwise.
            """
            bound_preds = [
                dreal.And(lb_i <= var_i, var_i <= ub_i)
                for var_i, (lb_i, ub_i) in self._var_bounds.items()
            ]
            cand_pred = self._gen_cand_pred(candidate)

            query = dreal.And(*bound_preds, cand_pred, *self._not_inv_cons)
            res = dreal.CheckSatisfiability(query, self._delta)

            self._models.clear()
            if res:
                self._models.append(
                    tuple(res[x_i] for x_i in self._old_state) + tuple(res[z_i] for z_i in self._percept)
                )
                return z3.sat
            else:
                # dReal produced no result for the query, so report unsat.
                return z3.unsat

        def dump_system_encoding(self, basename: str = "") -> None:
            """Print the bound predicates and the stored constraints.

            ``basename`` is accepted for interface compatibility and ignored.
            """
            print([
                dreal.And(lb_i <= var_i, var_i <= ub_i)
                for var_i, (lb_i, ub_i) in self._var_bounds.items()
            ])
            print(self._not_inv_cons)

        def model(self) -> Sequence[Tuple]:
            """Return the models recorded by the most recent ``check`` call."""
            return self._models

        def reason_unknown(self) -> str:
            """Not implemented for the dReal teacher."""
            raise NotImplementedError
    
    
    
    chsieh16's avatar
    chsieh16 committed
    class GurobiTeacherBase(TeacherBase):
        """Teacher skeleton that encodes its problem as a Gurobi model.

        The constructor creates the Gurobi model and its variable vectors,
        then calls the subclass hooks ``_add_system``, ``_add_unsafe`` and
        ``_add_objective`` in that order.
        """

        # Keyword-argument bundles for ``Model.addMVar``.
        FREEVAR = {"vtype": gp.GRB.CONTINUOUS, "lb": -np.inf, "ub": np.inf}
        NNEGVAR = {"vtype": gp.GRB.CONTINUOUS, "lb": 0.0, "ub": np.inf}
        TRIGVAR = {"vtype": gp.GRB.CONTINUOUS, "lb": -1.0, "ub": 1.0}

        def __init__(self, name: str,
                     state_dim: int, perc_dim: int, ctrl_dim: int, norm_ord=2) -> None:
            """Build the Gurobi model, its variables and the subclass constraints.

            Args:
                name: Name given to the Gurobi model.
                state_dim: Length of the old/new state vectors.
                perc_dim: Length of the perception vector.
                ctrl_dim: Length of the control vector.
                norm_ord: Norm order; must be ``1``, ``2`` or ``'inf'``.

            Raises:
                ValueError: If ``norm_ord`` is not one of the supported values.
            """
            super().__init__()

            model = gp.Model(name)
            self._gp_model = model

            if norm_ord not in (1, 2, 'inf'):
                raise ValueError("Norm order %s is not supported." % str(norm_ord))
            if norm_ord == 2:
                model.setParam("NonConvex", 2)
            self._norm_ord = norm_ord

            def free_mvar(label: str, dim: int):
                # Unbounded continuous vector variable of length ``dim``.
                return model.addMVar(shape=dim, name=label, **self.FREEVAR)

            # Old state, new state, perception and control variables.
            self._old_state = free_mvar('x', state_dim)
            self._new_state = free_mvar("x'", state_dim)
            self._percept = free_mvar('z', perc_dim)
            self._control = free_mvar('u', ctrl_dim)

            # FIXME Replace hardcoded contraints for this particular example
            self._add_system()
            self._add_unsafe()

            # Intermediate variables with constraints for candidates; created
            # only after the system/unsafe hooks have run.
            self._percept_diff = free_mvar("z-m(x)", perc_dim)

            self._add_objective()

            self._prev_candidate_constr: List = []
            self._cexs: List = []

        @property
        def state_dim(self) -> int:
            """Length of the old-state variable vector."""
            shape = self._old_state.shape
            return shape[0]

        @property
        def perc_dim(self) -> int:
            """Length of the perception variable vector."""
            shape = self._percept.shape
            return shape[0]

        @property
        def ctrl_dim(self) -> int:
            """Length of the control variable vector."""
            shape = self._control.shape
            return shape[0]

        @abc.abstractmethod
        def _add_system(self) -> None:
            """Hook invoked by ``__init__`` to add the system constraints."""
            raise NotImplementedError

        @abc.abstractmethod
        def _add_unsafe(self) -> None:
            """Hook invoked by ``__init__`` to add the unsafe-set constraints."""
            raise NotImplementedError

        @abc.abstractmethod
        def _add_objective(self) -> None:
            """Hook invoked by ``__init__`` to add the objective."""
            raise NotImplementedError

        def set_old_state_bound(self, lb, ub) -> None:
            """Assign lower/upper bounds to the old-state variables and update the model."""
            old_state = self._old_state
            old_state.lb = lb
            old_state.ub = ub

            self._gp_model.update()

        def check(self, candidate) -> z3.CheckSatResult:
            """Not implemented in this base class."""
            raise NotImplementedError

        def dump_system_encoding(self, basename: str = "") -> None:
            """ Dump optimization problem in LP format """
            # Fall back to the model's own name when no basename is supplied.
            stem = basename or self._gp_model.ModelName
            self._gp_model.write(stem + ".lp")

        def model(self) -> Sequence[Tuple]:
            """Return the stored counterexample list."""
            return self._cexs

        def reason_unknown(self) -> str:
            """Describe the current Gurobi model status as a short string."""
            status = self._gp_model.status
            if status == gp.GRB.INF_OR_UNBD:
                return '(model is infeasible or unbounded)'
            if status == gp.GRB.UNBOUNDED:
                return '(model is unbounded)'
            # TODO other status code
            return '(gurobi model status code %d)' % status
    
    
    
    class SymPyTeacherBase(TeacherBase):
        """Teacher skeleton that represents its problem with SymPy symbols.

        The constructor creates real-valued symbols for the old state, new
        state, perception and control vectors, gives every one of them the
        bounds ``(-oo, oo)``, and then calls the subclass hooks
        ``_add_system`` and ``_add_unsafe``.
        """

        def __init__(self, name: str,
                     state_dim: int, perc_dim: int, ctrl_dim: int, norm_ord) -> None:
            """Create the SymPy symbols and let the subclass add constraints.

            Args:
                name: Accepted for signature parity with the other teachers;
                    not used by this base class.
                state_dim: Number of state symbols.
                perc_dim: Number of perception symbols.
                ctrl_dim: Number of control symbols.
                norm_ord: Norm order, stored as ``self._norm_ord``.
            """
            super().__init__()
            self._norm_ord = norm_ord

            # Old state variables
            self._old_state = sympy.symbols(names=[f"x_{i}" for i in range(state_dim)], real=True)
            # New state variables
            self._new_state = sympy.symbols(names=[f"x'_{i}" for i in range(state_dim)], real=True)
            # Perception variables
            self._percept = sympy.symbols(names=[f"z_{i}" for i in range(perc_dim)], real=True)
            # Control variables
            self._control = sympy.symbols(names=[f"u_{i}" for i in range(ctrl_dim)], real=True)

            # Keys are SymPy symbols; values are (lower, upper) pairs that
            # start out as SymPy infinities.
            self._var_bounds = {
                var: (-sympy.oo, sympy.oo)
                for var in itertools.chain(self._old_state, self._new_state, self._percept, self._control)
            }  # type: Dict[sympy.Symbol, Tuple]

            self._not_inv_cons = []  # type: List

            # Subclass hooks run last so they can rely on all attributes above.
            self._add_system()
            self._add_unsafe()

        @abc.abstractmethod
        def _add_system(self) -> None:
            """Hook invoked by ``__init__`` to add the system constraints."""
            raise NotImplementedError

        @abc.abstractmethod
        def _add_unsafe(self) -> None:
            """Hook invoked by ``__init__`` to add the unsafe-set constraints."""
            raise NotImplementedError

        @property
        def state_dim(self) -> int:
            """Number of old-state symbols."""
            return len(self._old_state)

        @property
        def perc_dim(self) -> int:
            """Number of perception symbols."""
            return len(self._percept)

        @property
        def ctrl_dim(self) -> int:
            """Number of control symbols."""
            return len(self._control)

        def _set_var_bound(self, sym_vars: List[sympy.Symbol], lb: Sequence[float], ub: Sequence[float]) -> None:
            """Record ``(lb[i], ub[i])`` as the bounds of ``sym_vars[i]``."""
            assert len(lb) == len(sym_vars)
            assert len(ub) == len(sym_vars)
            for var_i, lb_i, ub_i in zip(sym_vars, lb, ub):
                self._var_bounds[var_i] = (lb_i, ub_i)

        def set_old_state_bound(self, lb: Sequence[float], ub: Sequence[float]) -> None:
            """Set lower/upper bounds for the old-state symbols."""
            self._set_var_bound(self._old_state, lb, ub)

        def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
            """Not implemented in this base class."""
            raise NotImplementedError

        def dump_system_encoding(self, basename: str = "") -> None:
            """Print the simplified conjunction of bounds and stored constraints.

            ``basename`` is accepted for interface compatibility and ignored.
            """
            query = sympy.And(
                *(sympy.And(lb_i <= var_i, var_i <= ub_i)
                  for var_i, (lb_i, ub_i) in self._var_bounds.items()),
                *self._not_inv_cons)
            print(sympy.simplify(query))

        def model(self) -> Sequence[Tuple]:
            """Not implemented in this base class."""
            raise NotImplementedError

        def reason_unknown(self) -> str:
            """Not implemented in this base class."""
            raise NotImplementedError