import numpy as np

from gem_stanley_teacher import GEMStanleyTeacher


class DTreeTeacherGEMStanley(GEMStanleyTeacher):
    """GEM Stanley teacher whose candidates are affine maps plus linear cuts.

    A candidate is a 4-tuple ``(a_mat, b_vec, coeff_mat, cut_vec)``. Installing
    it adds the constraints

        z_diff == z - (a_mat @ x + b_vec)
        coeff_mat @ z_diff <= cut_vec

    to the Gurobi model, where ``x``, ``z`` and ``z_diff`` are the base-class
    attributes ``_old_state``, ``_percept`` and ``_percept_diff``.
    """

    def __init__(self, name="dtree_gem_stanley") -> None:
        super().__init__(name=name)

    def _set_candidate(self, candidate) -> None:
        """Replace the constraints of the previous candidate with *candidate*'s.

        Args:
            candidate: 4-tuple ``(a_mat, b_vec, coeff_mat, cut_vec)``.
                Presumably ``a_mat @ x + b_vec`` has the shape of the percept
                and ``coeff_mat``/``cut_vec`` hold one row/entry per cut —
                TODO confirm against the base class's variable shapes.

        Unpacking raises if *candidate* does not have exactly four items.
        """
        # TODO: validate the shapes of the candidate components.
        a_mat, b_vec, coeff_mat, cut_vec = candidate
        # Variable aliases
        m = self._gp_model
        x = self._old_state
        z = self._percept
        z_diff = self._percept_diff

        # Remove constraints from the previous candidate first, so that
        # successive candidates replace rather than accumulate constraints.
        m.remove(self._prev_candidate_constr)
        self._prev_candidate_constr.clear()
        m.update()  # apply the pending removals before adding new constraints

        # Define z_diff as the residual of the percept w.r.t. the affine map.
        cons = m.addConstr(z_diff == z - (a_mat @ x + b_vec))
        self._prev_candidate_constr.append(cons)

        # Cut constraints on z_diff: coeff_mat @ z_diff <= cut_vec.
        # NOTE: Gurobi's '<' sense is non-strict, i.e. it means <=.
        cons = m.addMConstr(coeff_mat, z_diff, '<', cut_vec)
        self._prev_candidate_constr.append(cons)


def test_dtree_gem_stanley_teacher():
    """Smoke-test candidate installation on a fresh teacher.

    Installs the same candidate twice and checks that the second call replaces
    (rather than extends) the previously registered constraints, then dumps
    the model for manual inspection.
    """
    # 2x3: maps a 3-dimensional state to a 2-dimensional percept.
    a_mat = np.array([[0., -1., 0.],
                      [0., 0., -1.]])
    b_vec = np.zeros(2)
    # Four cuts over the 2-dimensional z_diff, one bound per row.
    coeff_mat = np.array([
        [-1, -1],
        [1, -1],
        [-1, 0],
        [0, 1],
    ])
    cut_vec = np.array([1, 2, 3, 4])
    candidate = (a_mat, b_vec, coeff_mat, cut_vec)

    teacher = DTreeTeacherGEMStanley()
    teacher._set_candidate(candidate)
    # One equality group defining z_diff plus one group of cut constraints.
    assert len(teacher._prev_candidate_constr) == 2

    # Installing a new candidate must drop the old constraints first.
    teacher._set_candidate(candidate)
    assert len(teacher._prev_candidate_constr) == 2

    teacher.dump_model()


# Run the smoke test when this file is executed directly as a script.
if __name__ == "__main__":
    test_dtree_gem_stanley_teacher()