# teacher_base.py
import abc
import itertools
#from typing import Dict, Hashable, List, Literal, Optional, Sequence, Tuple
from typing import Dict, Hashable, List, Optional, Sequence, Tuple

import gurobipy as gp
import numpy as np
import sympy
import z3
import dreal


def z3_float64_const_to_real(v: float) -> z3.RatNumRef:
    """Convert a Python float into a simplified Z3 real-valued term.

    The value is first wrapped as a Z3 ``Float64`` floating-point constant,
    then converted with ``z3.fpToReal`` and passed through ``z3.simplify``.

    Args:
        v: The float to convert.

    Returns:
        The simplified Z3 expression produced from ``v``.
    """
    as_fp = z3.FPVal(v, z3.Float64())
    as_real = z3.fpToReal(as_fp)
    return z3.simplify(as_real)


class TeacherBase(abc.ABC):
    """Abstract interface implemented by the solver-specific teacher classes."""

    @staticmethod
    def is_spurious_example(state_dim: int, perc_dim: int, candidate: z3.BoolRef, cex: Tuple[float, ...]) -> bool:
        """Decide whether ``cex`` falsifies ``candidate`` by direct substitution.

        The first ``state_dim`` entries of ``cex`` are substituted for the Z3
        reals ``x_0, x_1, ...`` and the following ``perc_dim`` entries for
        ``z_0, z_1, ...``; the resulting formula is then simplified.

        Args:
            state_dim: Number of ``x_i`` variables to substitute.
            perc_dim: Number of ``z_i`` variables to substitute.
            candidate: Z3 Boolean formula over the ``x_i`` and ``z_i`` reals.
            cex: Values for the state entries followed by the perception
                entries.

        Returns:
            ``True`` if the substituted formula simplifies to false,
            ``False`` if it simplifies to true.

        Raises:
            RuntimeError: If the simplified formula is neither literally true
                nor literally false.
        """
        substitutions = []
        for i in range(state_dim):
            substitutions.append(
                (z3.Real(f"x_{i}"), z3_float64_const_to_real(cex[i]))
            )
        for j in range(perc_dim):
            # Perception values are stored right after the state values.
            substitutions.append(
                (z3.Real(f"z_{j}"), z3_float64_const_to_real(cex[state_dim + j]))
            )

        evaluated = z3.simplify(z3.substitute(candidate, *substitutions))
        assert z3.is_bool(evaluated)

        if z3.is_false(evaluated):
            return True
        if z3.is_true(evaluated):
            return False
        raise RuntimeError(f"Cannot validate negative example {cex} by substitution")

    def __init__(self) -> None:
        """Initialize the base class; there is no shared state to set up."""

    @abc.abstractmethod
    def is_positive_example(self, ex) -> Optional[bool]:
        """Classify ``ex``; the meaning of ``None`` is defined by subclasses."""
        raise NotImplementedError

    @abc.abstractmethod
    def check(self, candidate) -> z3.CheckSatResult:
        """Check ``candidate`` and report the outcome as a Z3 check result."""
        raise NotImplementedError

    @abc.abstractmethod
    def model(self):
        """Return the model(s) found by the most recent ``check``."""
        raise NotImplementedError

    @abc.abstractmethod
    def reason_unknown(self) -> str:
        """Return a textual explanation for an inconclusive ``check``."""
        raise NotImplementedError

    @abc.abstractmethod
    def dump_system_encoding(self, basename: str = "") -> None:
        """Output the encoded system; ``basename`` usage is subclass-specific."""
        raise NotImplementedError


class DRealTeacherBase(TeacherBase):
    """Teacher base class that builds its query from dReal formulas.

    Subclasses provide ``_add_system`` and ``_add_unsafe`` (both invoked from
    ``__init__``) and are expected to override ``_gen_cand_pred``, which in
    this base class only raises ``NotImplementedError``.
    """

    @staticmethod
    def affine_trans_exprs(state, coeff: np.ndarray, intercept=None):
        """Build the row-wise expressions of ``coeff @ state + intercept``.

        Args:
            state: Sequence of variables/terms, one per column of ``coeff``.
            coeff: 2-D coefficient array.
            intercept: Optional sequence with one entry per row of ``coeff``;
                defaults to all zeros.

        Returns:
            A list with one expression per row of ``coeff``.
        """
        if intercept is None:
            intercept = np.zeros(coeff.shape[0])
        assert (len(intercept), len(state)) == coeff.shape
        return [
            b_i + sum(col_j * var_j for col_j, var_j in zip(row_i, state))
            for row_i, b_i in zip(coeff, intercept)
        ]

    def __init__(self, name: str,
                 state_dim: int, perc_dim: int, ctrl_dim: int, norm_ord,
                 delta: float = 0.001) -> None:
        """Create the dReal variables and let the subclass add its constraints.

        Args:
            name: Accepted for signature parity with the other teachers; not
                used by this class.
            state_dim: Number of old/new state variables.
            perc_dim: Number of perception variables.
            ctrl_dim: Number of control variables.
            norm_ord: Norm order, stored as ``self._norm_ord`` for subclasses.
            delta: Second argument passed to ``dreal.CheckSatisfiability``.
        """
        # Keep initialization consistent with GurobiTeacherBase, which also
        # chains to TeacherBase.__init__.
        super().__init__()
        self._norm_ord = norm_ord
        self._delta = delta

        # Old state variables
        self._old_state = [dreal.Variable(f"x[{i}]", dreal.Variable.Real) for i in range(state_dim)]
        # New state variables
        self._new_state = [dreal.Variable(f"x'[{i}]", dreal.Variable.Real) for i in range(state_dim)]
        # Perception variables
        self._percept = [dreal.Variable(f"z[{i}]", dreal.Variable.Real) for i in range(perc_dim)]
        # Control variables
        self._control = [dreal.Variable(f"u[{i}]", dreal.Variable.Real) for i in range(ctrl_dim)]

        # Bounds are tracked for old state, perception and control variables
        # only; new state variables are not given explicit bounds here.
        self._var_bounds = {
            var: (-np.inf, np.inf)
            for var in itertools.chain(self._old_state, self._percept, self._control)
        }  # type: Dict[dreal.Variable, Tuple[float, float]]
        self._not_inv_cons = []  # type: List
        self._models = []  # type: List

        self._add_system()
        self._add_unsafe()

    @abc.abstractmethod
    def _add_system(self) -> None:
        """Hook called from ``__init__`` to add the system constraints."""
        raise NotImplementedError

    @abc.abstractmethod
    def _add_unsafe(self) -> None:
        """Hook called from ``__init__`` to add the unsafe-set constraints."""
        raise NotImplementedError

    @property
    def state_dim(self) -> int:
        """Number of old state variables."""
        return len(self._old_state)

    @property
    def perc_dim(self) -> int:
        """Number of perception variables."""
        return len(self._percept)

    @property
    def ctrl_dim(self) -> int:
        """Number of control variables."""
        return len(self._control)

    def _gen_cand_pred(self, candidate):
        """Translate ``candidate`` into a dReal formula (not implemented here)."""
        raise NotImplementedError(f"TODO convert {candidate} to dReal formula")

    def _var_bound_preds(self) -> List:
        """Return one ``lb <= var <= ub`` conjunction per bounded variable."""
        return [
            dreal.And(lb_i <= var_i, var_i <= ub_i)
            for var_i, (lb_i, ub_i) in self._var_bounds.items()
        ]

    def _set_var_bound(self, smt_vars: List[dreal.Variable], lb: Sequence[float], ub: Sequence[float]) -> None:
        """Record ``(lb[i], ub[i])`` as the bounds of ``smt_vars[i]``."""
        assert len(lb) == len(smt_vars)
        assert len(ub) == len(smt_vars)
        for var_i, lb_i, ub_i in zip(smt_vars, lb, ub):
            self._var_bounds[var_i] = (lb_i, ub_i)

    def set_old_state_bound(self, lb: Sequence[float], ub: Sequence[float]) -> None:
        """Set lower/upper bounds for every old state variable."""
        self._set_var_bound(self._old_state, lb, ub)

    def check(self, candidate) -> z3.CheckSatResult:
        """Check the conjunction of bounds, candidate and stored constraints.

        Any model kept from a previous call is discarded. When dReal returns a
        result, the entries for the old state variables followed by the
        perception variables are stored as a single tuple retrievable through
        ``model()``.

        Returns:
            ``z3.sat`` if dReal returned a result, ``z3.unsat`` otherwise.
        """
        bound_preds = self._var_bound_preds()
        cand_pred = self._gen_cand_pred(candidate)
        query = dreal.And(*bound_preds, cand_pred, *self._not_inv_cons)
        res = dreal.CheckSatisfiability(query, self._delta)

        self._models.clear()
        # Compare against None explicitly instead of relying on the truthiness
        # of the returned box object.
        # NOTE(review): assumes CheckSatisfiability yields None when the query
        # is unsatisfiable - confirm against the dReal Python API.
        if res is not None:
            self._models.append(
                tuple(res[x_i] for x_i in self._old_state) + tuple(res[z_i] for z_i in self._percept)
            )
            return z3.sat
        return z3.unsat

    def dump_system_encoding(self, basename: str = "") -> None:
        """Print the bound predicates and stored constraints to stdout.

        ``basename`` is accepted for interface compatibility and is unused.
        """
        print(self._var_bound_preds())
        print(self._not_inv_cons)

    def model(self) -> Sequence[Tuple]:
        """Return the list of models stored by the most recent ``check``."""
        return self._models

    def reason_unknown(self) -> str:
        """Not implemented for this teacher."""
        raise NotImplementedError


class GurobiTeacherBase(TeacherBase):
    """Teacher base class that encodes its problem as a Gurobi model.

    Subclasses provide ``_add_system``, ``_add_unsafe`` and ``_add_objective``,
    which are invoked (in that order, see ``__init__``) while the model is
    being built.
    """

    # Keyword-argument presets for ``addMVar``.
    FREEVAR = {"vtype": gp.GRB.CONTINUOUS, "lb": -np.inf, "ub": np.inf}
    NNEGVAR = {"vtype": gp.GRB.CONTINUOUS, "lb": 0.0, "ub": np.inf}
    TRIGVAR = {"vtype": gp.GRB.CONTINUOUS, "lb": -1.0, "ub": 1.0}

    def __init__(self, name: str,
                 state_dim: int, perc_dim: int, ctrl_dim: int, norm_ord=2) -> None:
        """Create the Gurobi model, its variables and subclass constraints.

        Args:
            name: Name given to the Gurobi model.
            state_dim: Length of the old/new state variable vectors.
            perc_dim: Length of the perception variable vector.
            ctrl_dim: Length of the control variable vector.
            norm_ord: One of ``1``, ``2`` or ``'inf'``.

        Raises:
            ValueError: If ``norm_ord`` is not one of the supported values.
        """
        super().__init__()

        model = gp.Model(name)
        self._gp_model = model

        if norm_ord not in (1, 2, 'inf'):
            raise ValueError(f"Norm order {norm_ord!s} is not supported.")
        if norm_ord == 2:
            # The "NonConvex" parameter is only set for norm order 2.
            model.setParam("NonConvex", 2)
        self._norm_ord = norm_ord

        # Variable vectors, all unbounded continuous: old state, new state,
        # perception, control (added in this order).
        self._old_state = model.addMVar(shape=state_dim, name='x', **self.FREEVAR)
        self._new_state = model.addMVar(shape=state_dim, name="x'", **self.FREEVAR)
        self._percept = model.addMVar(shape=perc_dim, name='z', **self.FREEVAR)
        self._control = model.addMVar(shape=ctrl_dim, name='u', **self.FREEVAR)

        # FIXME Replace hardcoded constraints for this particular example
        self._add_system()
        self._add_unsafe()

        # Intermediate variable vector used by candidate constraints.
        self._percept_diff = model.addMVar(name="z-m(x)", shape=perc_dim, **self.FREEVAR)

        self._add_objective()

        self._prev_candidate_constr: List = []
        self._cexs: List = []

    @property
    def state_dim(self) -> int:
        """Length of the old state variable vector."""
        return self._old_state.shape[0]

    @property
    def perc_dim(self) -> int:
        """Length of the perception variable vector."""
        return self._percept.shape[0]

    @property
    def ctrl_dim(self) -> int:
        """Length of the control variable vector."""
        return self._control.shape[0]

    @abc.abstractmethod
    def _add_system(self) -> None:
        """Hook called from ``__init__`` to add the system constraints."""
        raise NotImplementedError

    @abc.abstractmethod
    def _add_unsafe(self) -> None:
        """Hook called from ``__init__`` to add the unsafe-set constraints."""
        raise NotImplementedError

    @abc.abstractmethod
    def _add_objective(self) -> None:
        """Hook called from ``__init__`` to set the model objective."""
        raise NotImplementedError

    def set_old_state_bound(self, lb, ub) -> None:
        """Assign ``lb``/``ub`` to the old state variables and update the model."""
        old_state = self._old_state
        old_state.lb = lb
        old_state.ub = ub
        self._gp_model.update()

    def check(self, candidate) -> z3.CheckSatResult:
        """Not implemented in this base class."""
        raise NotImplementedError

    def dump_system_encoding(self, basename: str = "") -> None:
        """Write the optimization problem to ``<basename>.lp``.

        When ``basename`` is empty, the Gurobi model's name is used instead.
        """
        stem = basename or self._gp_model.ModelName
        self._gp_model.write(stem + ".lp")

    def model(self) -> Sequence[Tuple]:
        """Return the stored list of counterexamples."""
        return self._cexs

    def reason_unknown(self) -> str:
        """Describe the current Gurobi model status as a short string."""
        status = self._gp_model.status
        if status == gp.GRB.INF_OR_UNBD:
            return '(model is infeasible or unbounded)'
        if status == gp.GRB.UNBOUNDED:
            return '(model is unbounded)'
        # TODO other status code
        return '(gurobi model status code %d)' % status


class SymPyTeacherBase(TeacherBase):
    """Teacher base class that represents its variables as SymPy symbols.

    Subclasses provide ``_add_system`` and ``_add_unsafe`` (both invoked from
    ``__init__``) as well as ``check``, ``model`` and ``reason_unknown``,
    which only raise ``NotImplementedError`` here.
    """

    def __init__(self, name: str,
                 state_dim: int, perc_dim: int, ctrl_dim: int, norm_ord) -> None:
        """Create the SymPy symbols and let the subclass add its constraints.

        Args:
            name: Accepted for signature parity with the other teachers; not
                used by this class.
            state_dim: Number of old/new state symbols.
            perc_dim: Number of perception symbols.
            ctrl_dim: Number of control symbols.
            norm_ord: Norm order, stored as ``self._norm_ord`` for subclasses.
        """
        # Keep initialization consistent with GurobiTeacherBase, which also
        # chains to TeacherBase.__init__.
        super().__init__()
        self._norm_ord = norm_ord

        # Old state variables
        self._old_state = sympy.symbols(names=[f"x_{i}" for i in range(state_dim)], real=True)
        # New state variables
        self._new_state = sympy.symbols(names=[f"x'_{i}" for i in range(state_dim)], real=True)
        # Perception variables
        self._percept = sympy.symbols(names=[f"z_{i}" for i in range(perc_dim)], real=True)
        # Control variables
        self._control = sympy.symbols(names=[f"u_{i}" for i in range(ctrl_dim)], real=True)

        # Every symbol (including the new state ones) starts out unbounded.
        # The keys are SymPy symbols, not dReal variables.
        self._var_bounds = {
            var: (-sympy.oo, sympy.oo)
            for var in itertools.chain(self._old_state, self._new_state, self._percept, self._control)
        }  # type: Dict[sympy.Symbol, Tuple[float, float]]

        self._not_inv_cons = []  # type: List

        self._add_system()
        self._add_unsafe()

    @abc.abstractmethod
    def _add_system(self) -> None:
        """Hook called from ``__init__`` to add the system constraints."""
        raise NotImplementedError

    @abc.abstractmethod
    def _add_unsafe(self) -> None:
        """Hook called from ``__init__`` to add the unsafe-set constraints."""
        raise NotImplementedError

    @property
    def state_dim(self) -> int:
        """Number of old state symbols."""
        return len(self._old_state)

    @property
    def perc_dim(self) -> int:
        """Number of perception symbols."""
        return len(self._percept)

    @property
    def ctrl_dim(self) -> int:
        """Number of control symbols."""
        return len(self._control)

    def _set_var_bound(self, sym_vars: List[sympy.Symbol], lb: Sequence[float], ub: Sequence[float]) -> None:
        """Record ``(lb[i], ub[i])`` as the bounds of ``sym_vars[i]``."""
        assert len(lb) == len(sym_vars)
        assert len(ub) == len(sym_vars)
        for var_i, lb_i, ub_i in zip(sym_vars, lb, ub):
            self._var_bounds[var_i] = (lb_i, ub_i)

    def set_old_state_bound(self, lb: Sequence[float], ub: Sequence[float]) -> None:
        """Set lower/upper bounds for every old state symbol."""
        self._set_var_bound(self._old_state, lb, ub)

    def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
        """Not implemented in this base class."""
        raise NotImplementedError

    def dump_system_encoding(self, basename: str = "") -> None:
        """Print the simplified conjunction of bounds and stored constraints.

        ``basename`` is accepted for interface compatibility and is unused.
        """
        query = sympy.And(
            *(sympy.And(lb_i <= var_i, var_i <= ub_i)
              for var_i, (lb_i, ub_i) in self._var_bounds.items()),
            *self._not_inv_cons)
        print(sympy.simplify(query))

    def model(self) -> Sequence[Tuple]:
        """Not implemented in this base class."""
        raise NotImplementedError

    def reason_unknown(self) -> str:
        """Not implemented in this base class."""
        raise NotImplementedError