import re

import gurobipy as gp
import z3

from teacher_base import GurobiTeacherBase


class DTreeGurobiTeacherBase(GurobiTeacherBase):
    """Check decision-tree shaped Z3 candidates against a Gurobi model.

    A candidate is a nest of Z3 if-then-else terms over linear comparisons.
    It is split into one conjunction per root-to-leaf path, each conjunction
    is added to the Gurobi model as linear constraints, and every solution
    Gurobi finds is collected as a counterexample.

    The class relies on attributes it does not define itself
    (``_gp_model``, ``_prev_candidate_constr``, ``_cexs``, ``_old_state``,
    ``_percept``, ``state_dim``, ``perc_dim``, ``is_spurious_example``);
    presumably they are provided by ``GurobiTeacherBase`` -- confirm there.
    """

    # Margin used to turn a strict inequality into a non-strict one:
    # ``a < b`` is encoded as ``a <= b - PRECISION``.
    PRECISION = 10**-3
    # Splits the printed name of a Z3 constant such as ``x_0`` into the base
    # name (``x``) and the index (``0``).
    Z3_VAR_RE = re.compile(r"(?P<var>\w+)_(?P<idx>\d+)")

    def _build_affine_expr(self, z3_expr: z3.ExprRef):
        """Translate a Z3 arithmetic term into the matching Gurobi expression.

        Numerals become Python numbers, a constant named ``<var>_<idx>`` is
        looked up in the Gurobi model as ``<var>[<idx>]``, and sums and binary
        products are translated recursively.

        :param z3_expr: Z3 term built from numerals, named constants, ``+``
            and binary ``*``.
        :return: a number, a Gurobi variable, or a Gurobi expression.
        :raises NotImplementedError: for a product of three or more operands.
        :raises RuntimeError: for any other kind of term.
        """
        if z3.is_rational_value(z3_expr):
            return z3_expr.as_fraction()
        if z3.is_int_value(z3_expr):
            # Integer numerals are not rational values; without this branch
            # they would reach the constant branch below and fail the name
            # pattern assertion.
            return z3_expr.as_long()
        if z3.is_var(z3_expr) or z3.is_const(z3_expr):
            result = self.Z3_VAR_RE.search(str(z3_expr))
            assert result is not None, str(z3_expr)
            var_name, idx = result.group("var"), result.group("idx")
            gp_var = self._gp_model.getVarByName(f"{var_name}[{idx}]")
            assert gp_var is not None
            return gp_var
        if z3.is_add(z3_expr):
            return sum(self._build_affine_expr(arg) for arg in z3_expr.children())
        if z3.is_mul(z3_expr):
            if len(z3_expr.children()) > 2:
                raise NotImplementedError("TODO: multiplication of three or more operands")
            lhs = self._build_affine_expr(z3_expr.arg(0))
            rhs = self._build_affine_expr(z3_expr.arg(1))
            return lhs * rhs
        raise RuntimeError(f"Only support affine expressions. {z3_expr.children()}")

    def _set_candidate(self, conjunct: z3.BoolRef) -> None:
        """Replace the candidate constraints in the Gurobi model.

        The constraints added by the previous call are removed first. The
        conjunct is then simplified and each of its literals is added to the
        model as one linear constraint. Negated inequalities are strict, so
        they are tightened by ``PRECISION``.

        :param conjunct: a conjunction of (possibly negated) ``==``, ``<=``
            and ``>=`` comparisons once simplified; ``True`` and ``False``
            are also accepted.
        :raises RuntimeError: if the simplified formula is not such a
            conjunction, or if it contains a negated equality, which has no
            single linear encoding.
        """
        # Variable Aliases
        m = self._gp_model

        # Remove constraints from previous candidate first
        m.remove(self._prev_candidate_constr)
        self._prev_candidate_constr.clear()
        m.update()

        conjunct = z3.simplify(conjunct, flat=True, arith_lhs=True)
        if z3.is_true(conjunct):
            return
        if z3.is_false(conjunct):
            # Nothing can satisfy this conjunct. Encode that as the constant
            # constraint 0 >= 1 so that the solver reports the model as
            # infeasible instead of solving without any candidate constraint.
            gp_cons = m.addConstr(gp.LinExpr() >= 1)
            self._prev_candidate_constr.append(gp_cons)
            return

        if z3.is_and(conjunct):
            pred_list = list(conjunct.children())
        elif z3.is_eq(conjunct) or z3.is_le(conjunct) or z3.is_ge(conjunct) or z3.is_not(conjunct):
            pred_list = [conjunct]
        else:
            raise RuntimeError(f"{conjunct} should be a conjunction.")

        for orig_pred in pred_list:
            negated = z3.is_not(orig_pred)
            pred = orig_pred.arg(0) if negated else orig_pred

            assert z3.is_eq(pred) or z3.is_le(pred) or z3.is_ge(pred), str(pred)
            if negated and z3.is_eq(pred):
                # A disequality is not one linear constraint. Dropping the
                # negation would constrain the model to the opposite region.
                raise RuntimeError(f"Unsupported atomic predicate expression {orig_pred}")

            lhs = self._build_affine_expr(pred.arg(0))
            rhs = self._build_affine_expr(pred.arg(1))

            if z3.is_eq(pred):
                cons = (lhs == rhs)
            elif z3.is_ge(pred):
                if not negated:
                    cons = (lhs >= rhs)
                else:  # !(lhs >= rhs) <=> (lhs < rhs) => lhs <= rhs - delta
                    cons = (lhs <= rhs - self.PRECISION)
            elif z3.is_le(pred):
                if not negated:
                    cons = (lhs <= rhs)
                else:  # !(lhs <= rhs) <=> (lhs > rhs) => lhs >= rhs + delta
                    cons = (lhs >= rhs + self.PRECISION)
            else:
                raise RuntimeError(f"Unsupported atomic predicate expression {pred}")

            gp_cons = m.addConstr(cons)
            self._prev_candidate_constr.append(gp_cons)

    def _candidate_to_conjuncts(self, candidate: z3.BoolRef):
        """Yield one conjunction per satisfiable-looking path of the candidate.

        The candidate is walked depth first, taking the "then" branch of each
        if-then-else before the "else" branch. The conditions met on the way
        are collected, with the negated condition on the "else" side. Paths
        ending in ``False`` are skipped; paths ending in ``True`` or in a
        comparison yield the conjunction of everything collected.

        :param candidate: ``True``, ``False``, a ``<``/``<=``/``>``/``>=``
            comparison, or an if-then-else whose condition is such a
            comparison and whose branches are again candidates.
        :return: generator of Z3 Boolean formulas.
        :raises RuntimeError: if a node or an if-then-else condition has any
            other shape.
        """
        init_path = [candidate]
        stack = [init_path]
        while stack:
            curr_path = stack.pop()  # remove this path
            curr_node = curr_path.pop()  # remove last node in this path
            if z3.is_false(curr_node):
                # the leaf node in this path is false. Skip
                continue
            elif z3.is_true(curr_node):
                if not curr_path:
                    yield z3.BoolVal(True)
                else:
                    yield z3.And(*curr_path)
            elif z3.is_gt(curr_node) or z3.is_ge(curr_node) \
                    or z3.is_lt(curr_node) or z3.is_le(curr_node):
                yield z3.And(*curr_path, curr_node)
            elif z3.is_app_of(curr_node, z3.Z3_OP_ITE):
                cond, left, right = curr_node.children()
                l_path = curr_path.copy()
                l_path.extend([cond, left])

                r_path = curr_path.copy()
                assert len(cond.children()) == 2
                lhs, rhs = cond.children()
                # Negate the comparison by flipping the operator. All four
                # inequality kinds accepted as leaves above are covered.
                if z3.is_le(cond):
                    not_cond = lhs > rhs
                elif z3.is_ge(cond):
                    not_cond = lhs < rhs
                elif z3.is_lt(cond):
                    not_cond = lhs >= rhs
                elif z3.is_gt(cond):
                    not_cond = lhs <= rhs
                else:
                    raise RuntimeError(f"Unexpected condition {cond} for ITE")
                r_path.extend([not_cond, right])

                # Pushed last, popped first: the "then" branch is explored
                # before the "else" branch.
                stack.append(r_path)
                stack.append(l_path)
            else:
                raise RuntimeError(f"Candidate formula {curr_node} should have been converted to DNF.")

    def check(self, candidate: z3.BoolRef) -> z3.CheckSatResult:
        """Search for counterexamples to the candidate, one conjunct at a time.

        For every conjunct of the candidate the model is constrained and
        optimized, and every solution that ``is_spurious_example`` does not
        reject is appended to ``self._cexs``.

        :param candidate: formula accepted by ``_candidate_to_conjuncts``.
        :return: ``z3.sat`` if at least one counterexample was kept,
            ``z3.unsat`` if none was, and ``z3.unknown`` as soon as Gurobi
            ends with a status other than optimal, suboptimal, infeasible or
            interrupted.
        :raises KeyboardInterrupt: if Gurobi reports an interrupted solve.
        """
        self._cexs.clear()
        # Enumerate the conjuncts once; the list serves both the count that
        # is printed and the loop below.
        conjuncts = list(self._candidate_to_conjuncts(candidate))
        # TODO(ANGELLO): add check if too many conjunctions
        print(f"number of conjunct:  {len(conjuncts)}")
        print("Checking candidate", flush=True)
        m = self._gp_model
        for conjunct in conjuncts:
            self._set_candidate(conjunct)
            m.optimize()
            if m.status == gp.GRB.INF_OR_UNBD:
                # Gurobi could not tell infeasible from unbounded. Solve again
                # without dual reductions to get a definite status. The
                # parameter is left at 0 for all later solves.
                m.setParam("DualReductions", 0)
                m.optimize()

            if m.status in [gp.GRB.OPTIMAL, gp.GRB.SUBOPTIMAL]:
                cex_list = []
                for i in range(m.SolCount):
                    # Selects which stored solution the Xn attributes read.
                    m.Params.SolutionNumber = i
                    cex = tuple(self._old_state.Xn) + tuple(self._percept.Xn)
                    cex_list.append(cex)
                # Conjuncts for which only spurious examples were found
                # contribute nothing; this is deliberately not an error.
                self._cexs.extend(
                    cex for cex in cex_list
                    if not self.is_spurious_example(self.state_dim, self.perc_dim, conjunct, cex)
                )
            elif m.status == gp.GRB.INFEASIBLE:
                continue
            elif m.status == gp.GRB.INTERRUPTED:
                raise KeyboardInterrupt
            else:
                return z3.unknown
        print("Done")
        if len(self._cexs) > 0:
            return z3.sat
        return z3.unsat