import numpy as np

from gem_stanley_teacher import GEMStanleyTeacher


class DTreeTeacherGEMStanley(GEMStanleyTeacher):
    """GEM-Stanley teacher whose candidate is an affine map plus linear cuts.

    A candidate is the tuple ``(a_mat, b_vec, coeff_mat, cut_vec)``. It is
    encoded into the solver model as two groups of constraints:

    * ``z_diff == z - (a_mat @ x + b_vec)``: the residual of the percept ``z``
      with respect to an affine function of the old state ``x``;
    * ``coeff_mat @ z_diff <= cut_vec``: linear cuts on that residual.
    """

    def __init__(self, name: str = "dtree_gem_stanley") -> None:
        """Create the teacher; ``name`` is forwarded to the base class."""
        super().__init__(name=name)

    @staticmethod
    def _check_candidate_shapes(a_mat, coeff_mat, cut_vec) -> None:
        """Raise ValueError if the candidate arrays have inconsistent shapes.

        ``np.shape`` is used instead of ``np.asarray(...).shape`` so that
        objects exposing ``.shape`` directly (e.g. sparse matrices) are not
        densified.
        """
        a_shape = np.shape(a_mat)
        coeff_shape = np.shape(coeff_mat)
        cut_shape = np.shape(cut_vec)
        if len(a_shape) != 2:
            raise ValueError(f"a_mat must be 2-D, got shape {a_shape}")
        if len(coeff_shape) != 2:
            raise ValueError(f"coeff_mat must be 2-D, got shape {coeff_shape}")
        # coeff_mat multiplies z_diff, and z_diff has one entry per row of a_mat.
        if coeff_shape[1] != a_shape[0]:
            raise ValueError(
                f"coeff_mat has {coeff_shape[1]} columns but a_mat has "
                f"{a_shape[0]} rows; they must match"
            )
        # Each cut (row of coeff_mat) needs exactly one right-hand side value.
        if len(cut_shape) == 1 and cut_shape[0] != coeff_shape[0]:
            raise ValueError(
                f"cut_vec has {cut_shape[0]} entries but coeff_mat has "
                f"{coeff_shape[0]} rows; they must match"
            )

    def _set_candidate(self, candidate) -> None:
        """Replace the previous candidate's constraints with ``candidate``'s.

        Args:
            candidate: Tuple ``(a_mat, b_vec, coeff_mat, cut_vec)``. See the
                class docstring for how each array is used.

        Raises:
            ValueError: If ``candidate`` does not unpack into four items, or
                its arrays have inconsistent shapes. Validation happens before
                the model is touched, so the previous candidate stays
                installed on failure.
        """
        a_mat, b_vec, coeff_mat, cut_vec = candidate
        self._check_candidate_shapes(a_mat, coeff_mat, cut_vec)

        # Variable aliases. These attributes are set up by the base class;
        # presumably they are gurobipy MVars (addMConstr needs an MVar or a
        # list of Vars for z_diff). TODO confirm against GEMStanleyTeacher.
        m = self._gp_model
        x = self._old_state
        z = self._percept
        z_diff = self._percept_diff

        # Remove constraints from the previous candidate first.
        m.remove(self._prev_candidate_constr)
        self._prev_candidate_constr.clear()
        m.update()

        # Define z_diff as the residual of z w.r.t. the affine map of x.
        cons = m.addConstr(z_diff == z - (a_mat @ x + b_vec))
        self._prev_candidate_constr.append(cons)

        # Restrict z_diff by the linear cuts: coeff_mat @ z_diff <= cut_vec
        # ('<' is gurobipy's GRB.LESS_EQUAL sense).
        cons = m.addMConstr(coeff_mat, z_diff, '<', cut_vec)
        self._prev_candidate_constr.append(cons)


def test_dtree_gem_stanley_teacher():
    """Install a small hand-written candidate and dump the resulting model."""
    a_mat = np.array([[0., -1., 0.],
                      [0., 0., -1.]])
    b_vec = np.zeros(2)
    coeff_mat = np.array([
        [-1, -1],
        [1, -1],
        [-1, 0],
        [0, 1],
    ])
    cut_vec = np.array([1, 2, 3, 4])
    candidate = (a_mat, b_vec, coeff_mat, cut_vec)

    teacher = DTreeTeacherGEMStanley()
    teacher._set_candidate(candidate)
    teacher.dump_model()


if __name__ == "__main__":
    test_dtree_gem_stanley_teacher()