From 68c8672839ea18a8d5f7a56da0b7e655ff030231 Mon Sep 17 00:00:00 2001 From: "Hsieh, Chiao" <chsieh16@illinois.edu> Date: Tue, 19 Apr 2022 10:12:59 -0500 Subject: [PATCH] Rename dump function --- firstpass_teacher.py | 4 ++-- gem_stanley_teacher.py | 18 +++++------------- teacher_base.py | 19 ++++++++++++------- 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/firstpass_teacher.py b/firstpass_teacher.py index 4de5124..2d1088c 100755 --- a/firstpass_teacher.py +++ b/firstpass_teacher.py @@ -60,9 +60,9 @@ def test_firstpass_teacher(): teacher = FirstpassTeacher() teacher.set_old_state_bound(lb=[6.0, 5.0], ub=[11.0, 10.0]) result = teacher.check(CAND0) - teacher.dump_model() + teacher.dump_system_encoding() result = teacher.check(CAND1) - teacher.dump_model("firstpass2") + teacher.dump_system_encoding("firstpass2") print(result) if result == z3.sat: print(teacher.model()) diff --git a/gem_stanley_teacher.py b/gem_stanley_teacher.py index 59c8ef0..14a20bf 100644 --- a/gem_stanley_teacher.py +++ b/gem_stanley_teacher.py @@ -225,17 +225,9 @@ class GEMStanleySymPyTeacher(SymPyTeacherBase): old_x, old_y, old_yaw = self._old_state new_x, new_y, new_yaw = self._new_state - if self._norm_ord == 1: - old_err = sympy.Abs(old_y) + sympy.Abs(old_yaw) - new_err = sympy.Abs(new_y) + sympy.Abs(new_yaw) - elif self._norm_ord == 2: - old_err = sympy.sqrt(old_y**2 + old_yaw**2) - new_err = sympy.sqrt(new_y**2 + new_yaw**2) - else: - assert self._norm_ord == "inf" - old_err = sympy.Max(sympy.Abs(old_y), sympy.Abs(old_yaw)) - new_err = sympy.Max(sympy.Abs(new_y), sympy.Abs(new_yaw)) - + norm_ord = sympy.oo if self._norm_ord == "inf" else self._norm_ord + old_err = sympy.Matrix([old_y, old_yaw]).norm(ord=norm_ord) + new_err = sympy.Matrix([new_y, new_yaw]).norm(ord=norm_ord) self._not_inv_cons.extend([ old_err <= new_err ]) @@ -269,7 +261,7 @@ def test_gem_stanley_sympy_teacher(): lb=(-sympy.oo, sympy.Rational("0.5"), 0.0), ub=(sympy.oo, sympy.Rational("1.2"), sympy.pi/12) ) - teacher.dump_system_formula() + teacher.dump_system_encoding() def test_gem_stanley_gurobi_teacher(): @@ -278,7 +270,7 @@ def test_gem_stanley_gurobi_teacher(): lb=(-np.inf, 0.5, 0.0625), ub=(np.inf, 1.0, 0.25) ) - teacher.dump_model() + teacher.dump_system_encoding() print(teacher.check(None)) print(teacher.model()) diff --git a/teacher_base.py b/teacher_base.py index 2645d85..21512f6 100644 --- a/teacher_base.py +++ b/teacher_base.py @@ -25,6 +25,10 @@ class TeacherBase(abc.ABC): def reason_unknown(self) -> str: raise NotImplementedError + @abc.abstractmethod + def dump_system_encoding(self, basename: str = "") -> None: + raise NotImplementedError + class DRealTeacherBase(TeacherBase): @staticmethod @@ -125,7 +129,7 @@ class DRealTeacherBase(TeacherBase): else: return z3.unsat - def dump_model(self, basename: str = "") -> None: + def dump_system_encoding(self, basename: str = "") -> None: print([ dreal.And(lb_i <= var_i, var_i <= ub_i) for var_i, (lb_i, ub_i) in self._var_bounds.items() @@ -241,6 +245,7 @@ class GurobiTeacherBase(TeacherBase): self._old_state.ub = ub def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult: + raise NotImplementedError("TODO convert from sympy") self._set_candidate(candidate) self._gp_model.optimize() if self._gp_model.status == gp.GRB.INF_OR_UNBD: @@ -254,7 +259,7 @@ class GurobiTeacherBase(TeacherBase): else: return z3.unknown - def dump_model(self, basename: str = "") -> None: + def dump_system_encoding(self, basename: str = "") -> None: """ Dump optimization problem in LP format """ if not basename: basename = self._gp_model.ModelName @@ -283,13 +288,13 @@ class SymPyTeacherBase(TeacherBase): self._norm_ord = norm_ord # Old state variables - self._old_state = sympy.symbols(names=[f"x[{i}]" for i in range(state_dim)], real=True) + self._old_state = sympy.symbols(names=[f"x_{i}" for i in range(state_dim)], real=True) # New state variables - self._new_state = sympy.symbols(names=[f"x'[{i}]" for i in range(state_dim)], real=True) + self._new_state = sympy.symbols(names=[f"x'_{i}" for i in range(state_dim)], real=True) # Perception variables - self._percept = sympy.symbols(names=[f"z[{i}]" for i in range(perc_dim)], real=True) + self._percept = sympy.symbols(names=[f"z_{i}" for i in range(perc_dim)], real=True) # Control variables - self._control = sympy.symbols(names=[f"u[{i}]" for i in range(ctrl_dim)], real=True) + self._control = sympy.symbols(names=[f"u_{i}" for i in range(ctrl_dim)], real=True) self._var_bounds = { var: (-sympy.oo, sympy.oo) @@ -333,7 +338,7 @@ class SymPyTeacherBase(TeacherBase): def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult: raise NotImplementedError - def dump_system_formula(self, basename: str = "") -> None: + def dump_system_encoding(self, basename: str = "") -> None: query = sympy.And( *(sympy.And(lb_i <= var_i, var_i <= ub_i) for var_i, (lb_i, ub_i) in self._var_bounds.items()), -- GitLab