import abc
from typing import List, Sequence, Tuple

import gurobipy as gp
import numpy as np
import z3


class TeacherBase(abc.ABC):
    """Abstract interface of a teacher that checks candidates.

    A subclass answers :meth:`check` with a z3 ``CheckSatResult``
    (sat / unsat / unknown). After a check it exposes the satisfying
    assignment through :meth:`model`, or an explanation of an inconclusive
    answer through :meth:`reason_unknown`.
    """

    def __init__(self) -> None:
        """Initialise the teacher; the base class keeps no state."""

    @abc.abstractmethod
    def check(self, candidate) -> z3.CheckSatResult:
        """Check ``candidate`` and report sat, unsat or unknown."""
        raise NotImplementedError

    @abc.abstractmethod
    def model(self):
        """Return the assignment found by the most recent :meth:`check`."""
        raise NotImplementedError

    @abc.abstractmethod
    def reason_unknown(self) -> str:
        """Describe why the most recent :meth:`check` was inconclusive."""
        raise NotImplementedError


class GurobiTeacherBase(TeacherBase):
    """Teacher that checks candidates by solving a Gurobi model.

    The model holds the old state ``x``, new state ``x'``, percept ``z`` and
    control ``u`` as continuous matrix variables. Subclasses add the
    remaining constraints through the abstract hooks ``_add_system`` and
    ``_add_unsafe`` (by their names, presumably the system dynamics and the
    unsafe set).

    A candidate describes the region ``||z - (A x + b)|| <= r`` in one of
    three norms. :meth:`check` installs that region and returns:

    * ``z3.sat`` when Gurobi finds a solution;
    * ``z3.unsat`` when Gurobi proves the model infeasible;
    * ``z3.unknown`` otherwise.
    """

    # Variable-attribute presets passed to ``addMVar``/``addVar``.
    # Unbounded continuous variable.
    FREEVAR = {"vtype": gp.GRB.CONTINUOUS, "lb": -np.inf, "ub": np.inf}
    # Non-negative continuous variable.
    NNEGVAR = {"vtype": gp.GRB.CONTINUOUS, "lb": 0.0, "ub": np.inf}
    # Continuous variable bounded to [-1, 1] (by its name, presumably meant
    # for sine/cosine values).
    TRIGVAR = {"vtype": gp.GRB.CONTINUOUS, "lb": -1.0, "ub": 1.0}

    def __init__(self, name: str,
                 state_dim: int, perc_dim: int, ctrl_dim: int) -> None:
        """Build the Gurobi model and its candidate-independent parts.

        :param name: Name given to the underlying ``gp.Model``.
        :param state_dim: Length of the old and new state vectors.
        :param perc_dim: Length of the percept vector ``z``.
        :param ctrl_dim: Length of the control vector ``u``.
        """
        super().__init__()

        self._gp_model = gp.Model(name)
        m = self._gp_model

        # Old state variables
        self._old_state = m.addMVar(shape=state_dim, name='x', **self.FREEVAR)
        # New state variables
        self._new_state = m.addMVar(shape=state_dim, name="x'", **self.FREEVAR)
        # Perception variables
        self._percept = m.addMVar(shape=perc_dim, name='z', **self.FREEVAR)
        # Control variables
        self._control = m.addMVar(shape=ctrl_dim, name='u', **self.FREEVAR)

        # FIXME Replace hardcoded constraints for this particular example
        # These hooks run before _percept_diff/_percept_dist and
        # _prev_candidate_constr exist, so they may only use the four
        # variable groups created above.
        self._add_system()
        self._add_unsafe()

        # Intermediate variables shared by every candidate:
        # z_diff is pinned to z - (Ax+b) per candidate in _set_candidate, and
        # z_dist is its element-wise absolute value (used by L1 and L-inf).
        z_diff = m.addMVar(name="z-(Ax+b)", shape=perc_dim, **self.FREEVAR)
        z_dist = m.addMVar(name="|z-(Ax+b)|", shape=perc_dim, **self.NNEGVAR)
        for zi, abs_zi in zip(z_diff.tolist(), z_dist.tolist()):
            m.addConstr(abs_zi == gp.abs_(zi))
        self._percept_diff = z_diff
        self._percept_dist = z_dist

        # Constraints (and auxiliary variables) owned by the currently
        # installed candidate; removed before the next one is installed.
        self._prev_candidate_constr: List = []

    @property
    def state_dim(self) -> int:
        """Length of the state vector."""
        return self._old_state.shape[0]

    @property
    def perc_dim(self) -> int:
        """Length of the percept vector."""
        return self._percept.shape[0]

    @property
    def ctrl_dim(self) -> int:
        """Length of the control vector."""
        return self._control.shape[0]

    @abc.abstractmethod
    def _add_system(self) -> None:
        """Add the subclass-specific system constraints to the model."""
        raise NotImplementedError

    @abc.abstractmethod
    def _add_unsafe(self) -> None:
        """Add the subclass-specific unsafe-region constraints to the model."""
        raise NotImplementedError

    def _set_candidate(self, candidate) -> None:
        """Replace the previously installed candidate with ``candidate``.

        :param candidate: ``(shape_sel, coeff, intercept, radius)``. It
            describes the region
            ``||z - (coeff @ x + intercept)|| <= radius``. ``shape_sel``
            selects the norm: 0 for L1, 1 for L2, 2 for L-inf. The same
            norm (squared for L2) becomes the minimization objective.
        :raises ValueError: if ``shape_sel`` is not 0, 1 or 2. The model is
            left untouched in that case.
        """
        shape_sel, coeff, intercept, radius = candidate  # TODO parse candidate
        # Validate before mutating the model. Otherwise a bad candidate would
        # leave it half-updated: previous constraints removed, new pin added,
        # stale objective kept.
        # Use "%s" rather than "%d" so a non-integer selector still produces
        # the intended ValueError instead of a formatting TypeError.
        if shape_sel not in (0, 1, 2):
            raise ValueError("Unknown shape. shape_sel=%s" % (shape_sel,))

        # Variable Aliases
        m = self._gp_model
        x = self._old_state
        z = self._percept
        z_diff = self._percept_diff
        z_dist = self._percept_dist

        # TODO Modify coefficients instead of removing and re-adding constraints
        # Remove constraints from the previous candidate first
        m.remove(self._prev_candidate_constr)
        self._prev_candidate_constr.clear()
        m.update()

        # Pin z_diff to z - (Ax + b) for this candidate
        cons = m.addConstr(z_diff == z - (coeff @ x + intercept))
        self._prev_candidate_constr.append(cons)

        # NOTE(review): the L2 case compares against radius*radius. A negative
        # radius is therefore feasible there but infeasible for L1/L-inf.
        # Confirm that candidates always have radius >= 0.
        if shape_sel == 0:
            cons_r = m.addConstr(z_dist.sum() <= radius, "||z-(Ax+b)||_1 <= r")
            self._prev_candidate_constr.append(cons_r)
            # L1-norm objective
            m.setObjective(z_dist.sum(), gp.GRB.MINIMIZE)
        elif shape_sel == 1:
            cons_r = m.addConstr(z_diff @ z_diff <= radius*radius,
                                 name="||z-(Ax+b)||^2 <= r^2")
            self._prev_candidate_constr.append(cons_r)
            # Squared L2-norm objective
            m.setObjective(z_diff @ z_diff, gp.GRB.MINIMIZE)
        else:  # shape_sel == 2, guaranteed by the validation above
            dist_oo = m.addVar(name="||z-(Ax+b)||_oo", **self.NNEGVAR)
            cons_max = m.addConstr(dist_oo == gp.max_(*z_dist.tolist()))
            cons_r = m.addConstr(dist_oo <= radius, "||z-(Ax+b)||_oo <= r")
            # Track the auxiliary variable too, so it is removed along with
            # this candidate's constraints.
            self._prev_candidate_constr.extend([dist_oo, cons_max, cons_r])
            # Loo-norm objective
            m.setObjective(dist_oo, gp.GRB.MINIMIZE)

    def set_old_state_bound(self, lb, ub) -> None:
        """Bound the old-state variables ``x``.

        ``lb`` and ``ub`` are assigned directly to the ``lb``/``ub``
        attributes of the old-state MVar.
        """
        self._old_state.lb = lb
        self._old_state.ub = ub

    def check(self, candidate) -> z3.CheckSatResult:
        """Install ``candidate`` and solve the model.

        :returns:
            * ``z3.sat`` if Gurobi ends OPTIMAL or SUBOPTIMAL; the solution
              is then available via :meth:`model`.
            * ``z3.unsat`` if Gurobi ends INFEASIBLE.
            * ``z3.unknown`` for any other status; see
              :meth:`reason_unknown`.
        :raises ValueError: propagated from an unknown norm selector.
        """
        self._set_candidate(candidate)
        self._gp_model.optimize()
        status = self._gp_model.status
        if status in (gp.GRB.OPTIMAL, gp.GRB.SUBOPTIMAL):
            return z3.sat
        elif status == gp.GRB.INFEASIBLE:
            return z3.unsat
        else:
            # NOTE(review): every objective set in _set_candidate is
            # non-negative, so INF_OR_UNBD presumably means infeasible here.
            # Setting Gurobi's DualReductions parameter to 0 would let it
            # report INFEASIBLE directly; consider it.
            return z3.unknown

    def dump_model(self, basename: str = "") -> None:
        """ Dump optimization problem in LP format

        Writes ``<basename>.lp``; ``basename`` defaults to the model name.
        """
        # NOTE(review): several constraint names contain spaces and "<=",
        # which the LP format may not accept. Gurobi may substitute default
        # names in the written file; confirm.
        if not basename:
            basename = self._gp_model.ModelName
        self._gp_model.write(basename + ".lp")

    def model(self) -> Sequence[Tuple]:
        """Return the solution point of the last :meth:`check`.

        :returns: If the last solve ended OPTIMAL or SUBOPTIMAL, a
            one-element list holding a tuple of the old-state ``x`` values
            followed by the percept ``z`` values. Otherwise, an empty list.
        """
        if self._gp_model.status in [gp.GRB.OPTIMAL, gp.GRB.SUBOPTIMAL]:
            x = self._old_state
            z = self._percept
            return [tuple(x.x) + tuple(z.x)]
        else:
            return []

    def reason_unknown(self) -> str:
        """Describe the Gurobi status behind a ``z3.unknown`` answer."""
        if self._gp_model.status == gp.GRB.INF_OR_UNBD:
            return '(model is infeasible or unbounded)'
        elif self._gp_model.status == gp.GRB.UNBOUNDED:
            return '(model is unbounded)'
        else:  # TODO other status code
            return '(gurobi model status code %d)' % self._gp_model.status