Skip to content
Snippets Groups Projects
Commit 68c86728 authored by chsieh16's avatar chsieh16
Browse files

Rename dump function

parent ba4170fd
No related branches found
No related tags found
No related merge requests found
......@@ -60,9 +60,9 @@ def test_firstpass_teacher():
teacher = FirstpassTeacher()
teacher.set_old_state_bound(lb=[6.0, 5.0], ub=[11.0, 10.0])
result = teacher.check(CAND0)
teacher.dump_model()
teacher.dump_system_encoding()
result = teacher.check(CAND1)
teacher.dump_model("firstpass2")
teacher.dump_system_encoding("firstpass2")
print(result)
if result == z3.sat:
print(teacher.model())
......
......@@ -225,17 +225,9 @@ class GEMStanleySymPyTeacher(SymPyTeacherBase):
old_x, old_y, old_yaw = self._old_state
new_x, new_y, new_yaw = self._new_state
if self._norm_ord == 1:
old_err = sympy.Abs(old_y) + sympy.Abs(old_yaw)
new_err = sympy.Abs(new_y) + sympy.Abs(new_yaw)
elif self._norm_ord == 2:
old_err = sympy.sqrt(old_y**2 + old_yaw**2)
new_err = sympy.sqrt(new_y**2 + new_yaw**2)
else:
assert self._norm_ord == "inf"
old_err = sympy.Max(sympy.Abs(old_y), sympy.Abs(old_yaw))
new_err = sympy.Max(sympy.Abs(new_y), sympy.Abs(new_yaw))
norm_ord = sympy.oo if self._norm_ord == "inf" else self._norm_ord
old_err = sympy.Matrix([old_y, old_yaw]).norm(ord=norm_ord)
new_err = sympy.Matrix([new_y, new_yaw]).norm(ord=norm_ord)
self._not_inv_cons.extend([
old_err <= new_err
])
......@@ -269,7 +261,7 @@ def test_gem_stanley_sympy_teacher():
lb=(-sympy.oo, sympy.Rational("0.5"), 0.0),
ub=(sympy.oo, sympy.Rational("1.2"), sympy.pi/12)
)
teacher.dump_system_formula()
teacher.dump_system_encoding()
def test_gem_stanley_gurobi_teacher():
......@@ -278,7 +270,7 @@ def test_gem_stanley_gurobi_teacher():
lb=(-np.inf, 0.5, 0.0625),
ub=(np.inf, 1.0, 0.25)
)
teacher.dump_model()
teacher.dump_system_encoding()
print(teacher.check(None))
print(teacher.model())
......
......@@ -25,6 +25,10 @@ class TeacherBase(abc.ABC):
def reason_unknown(self) -> str:
raise NotImplementedError
@abc.abstractmethod
def dump_system_encoding(self, basename: str = "") -> None:
raise NotImplementedError
class DRealTeacherBase(TeacherBase):
@staticmethod
......@@ -125,7 +129,7 @@ class DRealTeacherBase(TeacherBase):
else:
return z3.unsat
def dump_model(self, basename: str = "") -> None:
def dump_system_encoding(self, basename: str = "") -> None:
print([
dreal.And(lb_i <= var_i, var_i <= ub_i)
for var_i, (lb_i, ub_i) in self._var_bounds.items()
......@@ -241,6 +245,7 @@ class GurobiTeacherBase(TeacherBase):
self._old_state.ub = ub
def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
raise NotImplementedError("TODO convert from sympy")
self._set_candidate(candidate)
self._gp_model.optimize()
if self._gp_model.status == gp.GRB.INF_OR_UNBD:
......@@ -254,7 +259,7 @@ class GurobiTeacherBase(TeacherBase):
else:
return z3.unknown
def dump_model(self, basename: str = "") -> None:
def dump_system_encoding(self, basename: str = "") -> None:
""" Dump optimization problem in LP format """
if not basename:
basename = self._gp_model.ModelName
......@@ -283,13 +288,13 @@ class SymPyTeacherBase(TeacherBase):
self._norm_ord = norm_ord
# Old state variables
self._old_state = sympy.symbols(names=[f"x[{i}]" for i in range(state_dim)], real=True)
self._old_state = sympy.symbols(names=[f"x_{i}" for i in range(state_dim)], real=True)
# New state variables
self._new_state = sympy.symbols(names=[f"x'[{i}]" for i in range(state_dim)], real=True)
self._new_state = sympy.symbols(names=[f"x'_{i}" for i in range(state_dim)], real=True)
# Perception variables
self._percept = sympy.symbols(names=[f"z[{i}]" for i in range(perc_dim)], real=True)
self._percept = sympy.symbols(names=[f"z_{i}" for i in range(perc_dim)], real=True)
# Control variables
self._control = sympy.symbols(names=[f"u[{i}]" for i in range(ctrl_dim)], real=True)
self._control = sympy.symbols(names=[f"u_{i}" for i in range(ctrl_dim)], real=True)
self._var_bounds = {
var: (-sympy.oo, sympy.oo)
......@@ -333,7 +338,7 @@ class SymPyTeacherBase(TeacherBase):
def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
raise NotImplementedError
def dump_system_formula(self, basename: str = "") -> None:
def dump_system_encoding(self, basename: str = "") -> None:
query = sympy.And(
*(sympy.And(lb_i <= var_i, var_i <= ub_i)
for var_i, (lb_i, ub_i) in self._var_bounds.items()),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment