Skip to content
Snippets Groups Projects
Commit 8b74f63d authored by chsieh16's avatar chsieh16
Browse files

Reorder inheritance

parent 4bf0656f
No related branches found
No related tags found
No related merge requests found
...@@ -11,7 +11,7 @@ from typing import Dict, Hashable, List, Tuple ...@@ -11,7 +11,7 @@ from typing import Dict, Hashable, List, Tuple
import numpy as np import numpy as np
import z3 import z3
from dtree_learner import DTreeLearner as Learner from dtree_learner import DTreeLearner as Learner
from dtree_teacher_gem_stanley import DTreeGEMStanleyGurobiTeacher as Teacher from gem_stanley_teacher import DTreeGEMStanleyGurobiTeacher as Teacher
def load_examples(file_name: str, spec) -> Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]]]: def load_examples(file_name: str, spec) -> Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]]]:
...@@ -164,8 +164,8 @@ def main(dim:str, bnd_relax:float ): ...@@ -164,8 +164,8 @@ def main(dim:str, bnd_relax:float ):
result = [] result = []
for i, (part, (safe_dps, unsafe_dps, num_nan)) in enumerate(part_to_examples.items()): for i, (part, (safe_dps, unsafe_dps, num_nan)) in enumerate(part_to_examples.items()):
#if not i == 14: if not i == 1:
# continue continue
print("#"*80) print("#"*80)
print(f"# safe: {len(safe_dps)}; " print(f"# safe: {len(safe_dps)}; "
......
...@@ -3,10 +3,10 @@ import re ...@@ -3,10 +3,10 @@ import re
import gurobipy as gp import gurobipy as gp
import z3 import z3
from gem_stanley_teacher import GEMStanleyGurobiTeacher from teacher_base import GurobiTeacherBase
class DTreeGEMStanleyGurobiTeacher(GEMStanleyGurobiTeacher): class DTreeGurobiTeacherBase(GurobiTeacherBase):
PRECISION = 10**-3 PRECISION = 10**-3
Z3_VAR_RE = re.compile(r"(?P<var>\w+)_(?P<idx>\d+)") Z3_VAR_RE = re.compile(r"(?P<var>\w+)_(?P<idx>\d+)")
...@@ -153,18 +153,3 @@ class DTreeGEMStanleyGurobiTeacher(GEMStanleyGurobiTeacher): ...@@ -153,18 +153,3 @@ class DTreeGEMStanleyGurobiTeacher(GEMStanleyGurobiTeacher):
return z3.sat return z3.sat
else: else:
return z3.unsat return z3.unsat
def test_dtree_gem_stanley_gurobi_teacher():
x_0, x_1 = z3.Reals("x_0 x_1")
candidate = z3.If(x_0 >= 0.2,
z3.If(x_1 <= 1, z3.BoolVal(True), 0.5*x_0 + x_1 <= 3),
0.5*x_0 + x_1 > 3)
teacher = DTreeGEMStanleyGurobiTeacher(norm_ord=2)
print(teacher.check(candidate))
print(teacher.model())
if __name__ == "__main__":
test_dtree_gem_stanley_gurobi_teacher()
...@@ -6,7 +6,8 @@ import numpy as np ...@@ -6,7 +6,8 @@ import numpy as np
import sympy import sympy
import z3 import z3
from teacher_base import GurobiTeacherBase, DRealTeacherBase, SymPyTeacherBase from teacher_base import DRealTeacherBase, SymPyTeacherBase
from dtree_teacher_base import DTreeGurobiTeacherBase
WHEEL_BASE = 1.75 # m WHEEL_BASE = 1.75 # m
...@@ -79,7 +80,7 @@ class GEMStanleyDRealTeacher(DRealTeacherBase): ...@@ -79,7 +80,7 @@ class GEMStanleyDRealTeacher(DRealTeacherBase):
]) ])
class GEMStanleyGurobiTeacher(GurobiTeacherBase): class DTreeGEMStanleyGurobiTeacher(DTreeGurobiTeacherBase):
# Ideal perception as a linear transform from state to ground truth percept # Ideal perception as a linear transform from state to ground truth percept
PERC_GT = np.array([[0., -1., 0.], PERC_GT = np.array([[0., -1., 0.],
[0., 0., -1.]], float) [0., 0., -1.]], float)
...@@ -309,7 +310,7 @@ def test_gem_stanley_sympy_teacher(): ...@@ -309,7 +310,7 @@ def test_gem_stanley_sympy_teacher():
def test_gem_stanley_gurobi_teacher(): def test_gem_stanley_gurobi_teacher():
teacher = GEMStanleyGurobiTeacher(ultimate_bound=.4) teacher = DTreeGEMStanleyGurobiTeacher(ultimate_bound=.4)
teacher.set_old_state_bound( teacher.set_old_state_bound(
lb=(-np.inf, 0.5, 0.0625), lb=(-np.inf, 0.5, 0.0625),
ub=(np.inf, 1.0, 0.25) ub=(np.inf, 1.0, 0.25)
...@@ -317,6 +318,18 @@ def test_gem_stanley_gurobi_teacher(): ...@@ -317,6 +318,18 @@ def test_gem_stanley_gurobi_teacher():
teacher.dump_system_encoding() teacher.dump_system_encoding()
def test_dtree_gem_stanley_gurobi_teacher():
x_0, x_1 = z3.Reals("x_0 x_1")
candidate = z3.If(x_0 >= 0.2,
z3.If(x_1 <= 1, z3.BoolVal(True), 0.5*x_0 + x_1 <= 3),
0.5*x_0 + x_1 > 3)
teacher = DTreeGEMStanleyGurobiTeacher(norm_ord=2)
print(teacher.check(candidate))
print(teacher.model())
if __name__ == "__main__": if __name__ == "__main__":
test_gem_stanley_gurobi_teacher() test_gem_stanley_gurobi_teacher()
# test_gem_stanley_sympy_teacher() # test_gem_stanley_sympy_teacher()
test_dtree_gem_stanley_gurobi_teacher()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment