From 68c8672839ea18a8d5f7a56da0b7e655ff030231 Mon Sep 17 00:00:00 2001
From: "Hsieh, Chiao" <chsieh16@illinois.edu>
Date: Tue, 19 Apr 2022 10:12:59 -0500
Subject: [PATCH] Rename dump function

---
 firstpass_teacher.py   |  4 ++--
 gem_stanley_teacher.py | 18 +++++-------------
 teacher_base.py        | 19 ++++++++++++-------
 3 files changed, 19 insertions(+), 22 deletions(-)

diff --git a/firstpass_teacher.py b/firstpass_teacher.py
index 4de5124..2d1088c 100755
--- a/firstpass_teacher.py
+++ b/firstpass_teacher.py
@@ -60,9 +60,9 @@ def test_firstpass_teacher():
     teacher = FirstpassTeacher()
     teacher.set_old_state_bound(lb=[6.0, 5.0], ub=[11.0, 10.0])
     result = teacher.check(CAND0)
-    teacher.dump_model()
+    teacher.dump_system_encoding()
     result = teacher.check(CAND1)
-    teacher.dump_model("firstpass2")
+    teacher.dump_system_encoding("firstpass2")
     print(result)
     if result == z3.sat:
         print(teacher.model())
diff --git a/gem_stanley_teacher.py b/gem_stanley_teacher.py
index 59c8ef0..14a20bf 100644
--- a/gem_stanley_teacher.py
+++ b/gem_stanley_teacher.py
@@ -225,17 +225,9 @@ class GEMStanleySymPyTeacher(SymPyTeacherBase):
         old_x, old_y, old_yaw = self._old_state
         new_x, new_y, new_yaw = self._new_state
 
-        if self._norm_ord == 1:
-            old_err = sympy.Abs(old_y) + sympy.Abs(old_yaw)
-            new_err = sympy.Abs(new_y) + sympy.Abs(new_yaw)
-        elif self._norm_ord == 2:
-            old_err = sympy.sqrt(old_y**2 + old_yaw**2)
-            new_err = sympy.sqrt(new_y**2 + new_yaw**2)
-        else:
-            assert self._norm_ord == "inf"
-            old_err = sympy.Max(sympy.Abs(old_y), sympy.Abs(old_yaw))
-            new_err = sympy.Max(sympy.Abs(new_y), sympy.Abs(new_yaw))
-
+        norm_ord = sympy.oo if self._norm_ord == "inf" else self._norm_ord
+        old_err = sympy.Matrix([old_y, old_yaw]).norm(ord=norm_ord)
+        new_err = sympy.Matrix([new_y, new_yaw]).norm(ord=norm_ord)
         self._not_inv_cons.extend([
             old_err <= new_err
         ])
@@ -269,7 +261,7 @@ def test_gem_stanley_sympy_teacher():
         lb=(-sympy.oo, sympy.Rational("0.5"), 0.0),
         ub=(sympy.oo, sympy.Rational("1.2"), sympy.pi/12)
     )
-    teacher.dump_system_formula()
+    teacher.dump_system_encoding()
 
 
 def test_gem_stanley_gurobi_teacher():
@@ -278,7 +270,7 @@ def test_gem_stanley_gurobi_teacher():
         lb=(-np.inf, 0.5, 0.0625),
         ub=(np.inf, 1.0, 0.25)
     )
-    teacher.dump_model()
+    teacher.dump_system_encoding()
     print(teacher.check(None))
     print(teacher.model())
 
diff --git a/teacher_base.py b/teacher_base.py
index 2645d85..21512f6 100644
--- a/teacher_base.py
+++ b/teacher_base.py
@@ -25,6 +25,10 @@ class TeacherBase(abc.ABC):
     def reason_unknown(self) -> str:
         raise NotImplementedError
 
+    @abc.abstractmethod
+    def dump_system_encoding(self, basename: str = "") -> None:
+        raise NotImplementedError
+
 
 class DRealTeacherBase(TeacherBase):
     @staticmethod
@@ -125,7 +129,7 @@ class DRealTeacherBase(TeacherBase):
         else:
             return z3.unsat
 
-    def dump_model(self, basename: str = "") -> None:
+    def dump_system_encoding(self, basename: str = "") -> None:
         print([
             dreal.And(lb_i <= var_i, var_i <= ub_i)
             for var_i, (lb_i, ub_i) in self._var_bounds.items()
@@ -241,6 +245,7 @@ class GurobiTeacherBase(TeacherBase):
         self._old_state.ub = ub
 
     def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
+        raise NotImplementedError("TODO convert from sympy")
         self._set_candidate(candidate)
         self._gp_model.optimize()
         if self._gp_model.status == gp.GRB.INF_OR_UNBD:
@@ -254,7 +259,7 @@ class GurobiTeacherBase(TeacherBase):
         else:
             return z3.unknown
 
-    def dump_model(self, basename: str = "") -> None:
+    def dump_system_encoding(self, basename: str = "") -> None:
         """ Dump optimization problem in LP format """
         if not basename:
             basename = self._gp_model.ModelName
@@ -283,13 +288,13 @@ class SymPyTeacherBase(TeacherBase):
         self._norm_ord = norm_ord
 
         # Old state variables
-        self._old_state = sympy.symbols(names=[f"x[{i}]" for i in range(state_dim)], real=True)
+        self._old_state = sympy.symbols(names=[f"x_{i}" for i in range(state_dim)], real=True)
         # New state variables
-        self._new_state = sympy.symbols(names=[f"x'[{i}]" for i in range(state_dim)], real=True)
+        self._new_state = sympy.symbols(names=[f"x'_{i}" for i in range(state_dim)], real=True)
         # Perception variables
-        self._percept = sympy.symbols(names=[f"z[{i}]" for i in range(perc_dim)], real=True)
+        self._percept = sympy.symbols(names=[f"z_{i}" for i in range(perc_dim)], real=True)
         # Control variables
-        self._control = sympy.symbols(names=[f"u[{i}]" for i in range(ctrl_dim)], real=True)
+        self._control = sympy.symbols(names=[f"u_{i}" for i in range(ctrl_dim)], real=True)
 
         self._var_bounds = {
             var: (-sympy.oo, sympy.oo)
@@ -333,7 +338,7 @@ class SymPyTeacherBase(TeacherBase):
     def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
         raise NotImplementedError
 
-    def dump_system_formula(self, basename: str = "") -> None:
+    def dump_system_encoding(self, basename: str = "") -> None:
         query = sympy.And(
             *(sympy.And(lb_i <= var_i, var_i <= ub_i)
               for var_i, (lb_i, ub_i) in self._var_bounds.items()),
-- 
GitLab