Skip to content
Snippets Groups Projects
Commit 933d64c4 authored by chsieh16's avatar chsieh16
Browse files

Use Z3 instead of sympy for simplifying formula

parent a2f19d65
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,7 @@ import os
from typing import Any, Dict, List, MutableSet, Tuple
import numpy as np
import sympy
import z3
from learner_base import LearnerBase
......@@ -150,7 +150,7 @@ class DTreeLearner(LearnerBase):
row = itertools.chain(f, [label]) # append label at the end of each row
data_out.writerow(row)
def learn(self) -> sympy.logic.boolalg.Boolean:
def learn(self) -> z3.BoolRef:
res = os.popen(self.exec).read()
assert os.path.exists(self.tree_out), "if learned successfully" \
f"there should be a json file in {self.dir_name}"
......@@ -161,43 +161,52 @@ class DTreeLearner(LearnerBase):
ite_expr = self._subs_basevar_w_states(ite_expr)
return ite_expr
def _subs_basevar_w_states(self, ite_expr) -> sympy.logic.boolalg.Boolean:
state_vars = sympy.symbols([f"x_{i}" for i in range(self.state_dim)])
state_vec = sympy.Matrix(state_vars)
perc_vars = sympy.symbols([f"z_{i}" for i in range(self.perc_dim)])
perc_vec = sympy.Matrix(perc_vars)
def _subs_basevar_w_states(self, ite_expr) -> z3.BoolRef:
state_vars = z3.Reals([f"x_{i}" for i in range(self.state_dim)])
perc_vars = z3.Reals([f"z_{i}" for i in range(self.perc_dim)])
subs_basevar = []
for basevar, (trans, j) in self._basevar_trans_map.items():
a_mat, b_vec = trans
a_mat, b_vec = sympy.Matrix(a_mat), sympy.Matrix(b_vec)
expanded_basevar = (perc_vec - (a_mat @ state_vec + b_vec))[j]
subs_basevar.append((basevar, expanded_basevar))
return ite_expr.subs(subs_basevar)
expanded_basevar = perc_vars[j] - ((a_mat @ state_vars)[j] + b_vec[j])
expanded_basevar = z3.simplify(expanded_basevar)
subs_basevar.append((z3.Real(basevar), expanded_basevar))
return z3.substitute(ite_expr, *subs_basevar)
@staticmethod
def get_pre_from_json(path):
def get_pre_from_json(self, path):
try:
with open(path) as json_file:
tree = json.load(json_file)
return DTreeLearner.parse_tree(tree)
return self.parse_tree(tree)
except json.JSONDecodeError:
raise ValueError(f"cannot parse {path} as a json file")
@staticmethod
def parse_tree(tree) -> sympy.logic.boolalg.Boolean:
def parse_tree(self, tree) -> z3.BoolRef:
if tree['children'] is None:
# At a leaf node, return the clause
if tree['classification']:
return sympy.true # True leaf node
return z3.BoolVal(True) # True leaf node
else:
return sympy.false # False leaf node
return z3.BoolVal(False) # False leaf node
elif len(tree['children']) == 2:
# Post-order traversal
left = DTreeLearner.parse_tree(tree['children'][0])
right = DTreeLearner.parse_tree(tree['children'][1])
left = self.parse_tree(tree['children'][0])
right = self.parse_tree(tree['children'][1])
# Create an ITE expression tree
cond = sympy.sympify(f"{tree['attribute']} <= {tree['cut']}")
return sympy.logic.boolalg.ITE(cond, left, right)
z3_expr = z3.Sum(*(coeff*z3.Real(base_fvar) for base_fvar, coeff
in self._var_coeff_map[tree['attribute']].items()))
z3_cut = z3.simplify(z3.fpToReal(z3.FPVal(tree['cut'], z3.Float64())))
if z3.is_true(left):
if z3.is_true(right):
return z3.BoolVal(True)
elif z3.is_false(right):
return (z3_expr <= z3_cut)
if z3.is_false(left):
if z3.is_true(right):
return (z3_expr > z3_cut)
elif z3.is_false(right):
return z3.BoolVal(False)
# else:
return z3.If((z3_expr <= z3_cut), left, right)
else:
raise ValueError("error parsing the json object as a binary decision tree)")
......@@ -241,7 +250,7 @@ def test_dtree_learner():
]
learner.add_negative_examples(*neg_examples)
print(learner.learn())
print("Learned ITE expression:", learner.learn())
def test_sample_to_feature():
......@@ -269,20 +278,18 @@ def test_sample_to_feature():
def test_parse_json():
json_obj = json.loads("""
{"attribute":"((1*fvar0_A0) + (1*fvar1_A0))","cut":-0.01318414,"classification":0,
{"attribute":"((1*fvar0_A0) + (1*fvar1_A0))","cut":-0.01,"classification":0,
"children":[
{"attribute":"fvar1_A0","cut":0.01403625,"classification":0,
{"attribute":"fvar1_A0","cut":0.625,"classification":0,
"children":[{"attribute":"","cut":0,"classification":true,"children":null},
{"attribute":"","cut":0,"classification":false,"children":null}]
},
{"attribute":"fvar1_A1","cut":-0.003193465,"classification":0,
{"attribute":"fvar1_A1","cut":-0.15,"classification":0,
"children":[{"attribute":"","cut":0,"classification":true,"children":null},
{"attribute":"","cut":0,"classification":false,"children":null}]
}
]
}""")
tree = DTreeLearner.parse_tree(json_obj)
a_mat_0 = np.array([[0., -1., 0.],
[0., 0., -1.]])
b_vec_0 = np.zeros(2)
......@@ -293,6 +300,7 @@ def test_parse_json():
learner = DTreeLearner(state_dim=3, perc_dim=2)
learner.set_grammar([(a_mat_0, b_vec_0), (a_mat_1, b_vec_1)])
tree = learner.parse_tree(json_obj)
print(learner._subs_basevar_w_states(tree))
......
......@@ -3,52 +3,57 @@ from typing import List, Tuple
import numpy as np
import z3
import sympy
from dtree_learner import DTreeLearner as Learner
from dtree_teacher_gem_stanley import DTreeGEMStanleyGurobiTeacher as Teacher
def load_positive_examples(file_name: str) -> List[Tuple[float, ...]]:
def load_examples(file_name: str, spec) -> Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]]]:
print("Loading examples")
with open(file_name, "rb") as pickle_file_io:
pkl_data = pickle.load(pickle_file_io)
truth_samples_seq = pkl_data["truth_samples"]
i_th = 0 # select only the i-th partition
truth_samples_seq = truth_samples_seq[i_th:i_th+1]
print("Representative point in partition:", truth_samples_seq[0][0])
truth_samples_seq = [(t, [s for s in raw_samples if not any(np.isnan(s))])
for t, raw_samples in truth_samples_seq]
# Convert from sampled states and percepts to positive examples for learning
return [
s for _, samples in truth_samples_seq for s in samples
]
# Convert from sampled states and percepts to positive and negative examples for learning
pos_exs, neg_exs, num_excl_exs = [], [], 0
for _, ss in truth_samples_seq:
for s in ss:
ret = spec(s)
if np.any(np.isnan(s)) or ret is None:
num_excl_exs += 1
elif ret:
pos_exs.append(s)
else:
neg_exs.append(s)
print("# Exculded examples:", num_excl_exs)
return pos_exs, neg_exs
def test_synth_dtree():
positive_examples = load_positive_examples(
"data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle")
# positive_examples = positive_examples[:20:] # Select only first few examples
teacher = Teacher(norm_ord=1)
# 0.0 <= x <= 32.0 and -1.0 <= y <= -0.9 and 0.2 <= theta <= 0.22
teacher.set_old_state_bound(lb=[0.0, -1.0, 0.2], ub=[32.0, -0.9, 0.22])
# teacher.set_old_state_bound(lb=[0.0, -0.9, 2*np.pi/60], ub=[32.0, -0.6, 3*np.pi/60])
# teacher.set_old_state_bound(lb=[0.0, 0.3, 1*np.pi/60], ub=[32.0, 0.9, 5*np.pi/60])
positive_examples, negative_examples = load_examples(
# "data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle",
"data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle",
teacher.is_positive_example)
print("# positive examples: %d" % len(positive_examples))
print("# negative examples: %d" % len(negative_examples))
ex_dim = len(positive_examples[0])
print("#examples: %d" % len(positive_examples))
print("Dimension of each example: %d" % ex_dim)
assert all(len(ex) == ex_dim and not any(np.isnan(ex))
for ex in positive_examples)
teacher = Teacher()
assert teacher.state_dim + teacher.perc_dim == ex_dim
# 0.0 <= x <= 30.0 and -1.0 <= y <= 0.9 and 0.2 <= theta <= 0.22
teacher.set_old_state_bound(lb=[0.0, -1.0, 0.2], ub=[30.0, -0.9, 0.22])
synth_dtree(positive_examples, teacher, num_max_iterations=2000)
synth_dtree(positive_examples, negative_examples, teacher, num_max_iterations=2000)
def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
def synth_dtree(positive_examples, negative_examples, teacher, num_max_iterations: int = 10):
learner = Learner(state_dim=teacher.state_dim,
perc_dim=teacher.perc_dim, timeout=20000)
......@@ -61,6 +66,7 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
learner.set_grammar([(a_mat_0, b_vec_0)])
learner.add_positive_examples(*positive_examples)
learner.add_negative_examples(*negative_examples)
past_candidate_list = []
for k in range(num_max_iterations):
......@@ -79,15 +85,16 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
print(f"Satisfiability: {result}")
if result == z3.sat:
negative_examples = teacher.model()
assert len(negative_examples) > 0
# assert len(negative_examples) > 0
print(f"negative examples: {negative_examples}")
assert validate_cexs(teacher.state_dim, teacher.perc_dim, candidate, negative_examples)
# assert validate_cexs(teacher.state_dim, teacher.perc_dim, candidate, negative_examples)
learner.add_negative_examples(*negative_examples)
continue
elif result == z3.unsat:
print("we are done!")
print(f"Simplified candidate: {z3.simplify(candidate)}")
return past_candidate_list
else:
print("Reason Unknown", teacher.reason_unknown())
......@@ -99,17 +106,10 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
def validate_cexs(state_dim: int, perc_dim: int,
candidate: sympy.logic.boolalg.Boolean,
candidate: z3.BoolRef,
cexs: List[Tuple[float]]) -> bool:
spurious_cexs = []
for cex in cexs:
state_subs_map = [(f"x_{i}", cex[i]) for i in range(state_dim)]
perc_subs_map = [(f"z_{i}", cex[i+state_dim]) for i in range(perc_dim)]
sub_map = state_subs_map + perc_subs_map
val = candidate.subs(sub_map)
assert isinstance(val, sympy.logic.boolalg.BooleanAtom)
if val == sympy.false:
spurious_cexs.append(cex)
spurious_cexs = [cex for cex in cexs
if Teacher.is_spurious_example(state_dim, perc_dim, candidate, cex)]
if not spurious_cexs:
return True
else:
......
import re
import gurobipy as gp
import numpy as np
import sympy
import z3
from gem_stanley_teacher import GEMStanleyGurobiTeacher
......@@ -10,30 +8,29 @@ from gem_stanley_teacher import GEMStanleyGurobiTeacher
class DTreeGEMStanleyGurobiTeacher(GEMStanleyGurobiTeacher):
PRECISION = 10**-3
SYMPY_VAR_RE = re.compile(r"(?P<var>\w+)_(?P<idx>\d+)")
def _build_affine_expr(self, sympy_expr: sympy.Expr):
if isinstance(sympy_expr, sympy.Symbol):
result = self.SYMPY_VAR_RE.search(sympy_expr.name)
assert result is not None
Z3_VAR_RE = re.compile(r"(?P<var>\w+)_(?P<idx>\d+)")
def _build_affine_expr(self, z3_expr: z3.ExprRef):
if z3.is_rational_value(z3_expr):
return z3_expr.as_fraction()
elif z3.is_var(z3_expr) or z3.is_const(z3_expr):
result = self.Z3_VAR_RE.search(str(z3_expr))
assert result is not None, str(z3_expr)
var_name, idx = result.group("var"), result.group("idx")
gp_var = self._gp_model.getVarByName(f"{var_name}[{idx}]")
assert gp_var is not None
return gp_var
elif isinstance(sympy_expr, sympy.Number):
return float(sympy_expr)
elif isinstance(sympy_expr, sympy.Add):
return sum(self._build_affine_expr(arg) for arg in sympy_expr.args)
elif isinstance(sympy_expr, sympy.Mul):
if len(sympy_expr.args) > 2:
elif z3.is_add(z3_expr):
return sum(self._build_affine_expr(arg) for arg in z3_expr.children())
elif z3.is_mul(z3_expr):
if len(z3_expr.children()) > 2:
raise NotImplementedError("TODO: multiplication of three or more operands")
lhs = self._build_affine_expr(sympy_expr.args[0])
rhs = self._build_affine_expr(sympy_expr.args[1])
lhs = self._build_affine_expr(z3_expr.arg(0))
rhs = self._build_affine_expr(z3_expr.arg(1))
return lhs * rhs
else:
raise RuntimeError("Only support affine expressions.")
raise RuntimeError(f"Only support affine expressions. {z3_expr.children()}")
def _set_candidate(self, conjunct: sympy.logic.boolalg.Boolean) -> None:
def _set_candidate(self, conjunct: z3.BoolRef) -> None:
# Variable Aliases
m = self._gp_model
......@@ -42,51 +39,88 @@ class DTreeGEMStanleyGurobiTeacher(GEMStanleyGurobiTeacher):
self._prev_candidate_constr.clear()
m.update()
if conjunct is sympy.true:
conjunct = z3.simplify(conjunct, flat=True, arith_lhs=True)
if z3.is_true(conjunct):
return
elif isinstance(conjunct, sympy.And):
pred_list = conjunct.args
elif isinstance(conjunct, sympy.core.relational.Relational):
elif z3.is_and(conjunct):
pred_list = list(conjunct.children())
elif z3.is_eq(conjunct) or z3.is_le(conjunct) or z3.is_ge(conjunct) or z3.is_not(conjunct):
pred_list = [conjunct]
else:
raise RuntimeError(f"{conjunct} should be a conjunction.")
for pred in pred_list:
assert isinstance(pred, sympy.core.relational.Relational)
lhs = self._build_affine_expr(pred.lhs)
rhs = self._build_affine_expr(pred.rhs)
for orig_pred in pred_list:
if z3.is_not(orig_pred):
pred = orig_pred.arg(0)
else:
pred = orig_pred
if isinstance(pred, sympy.Equality):
assert z3.is_eq(pred) or z3.is_le(pred) or z3.is_ge(pred), str(pred)
lhs = self._build_affine_expr(pred.arg(0))
rhs = self._build_affine_expr(pred.arg(1))
if z3.is_eq(pred):
cons = (lhs == rhs)
elif isinstance(pred, sympy.GreaterThan):
cons = (lhs >= rhs)
elif isinstance(pred, sympy.StrictGreaterThan):
cons = (lhs >= rhs + self.PRECISION)
elif isinstance(pred, sympy.LessThan):
cons = (lhs <= rhs)
elif isinstance(pred, sympy.StrictLessThan):
cons = (lhs <= rhs - self.PRECISION)
elif z3.is_ge(pred):
if not z3.is_not(orig_pred):
cons = (lhs >= rhs)
else: # !(lhs >= rhs) <=> (lhs < rhs) => lhs <= rhs - ð
cons = (lhs <= rhs - self.PRECISION)
elif z3.is_le(pred):
if not z3.is_not(orig_pred):
cons = (lhs <= rhs)
else: # !(lhs <= rhs) <=> (lhs > rhs) => lhs >= rhs + ð
cons = (lhs >= rhs + self.PRECISION)
else:
raise RuntimeError(f"Unsupprted relational expression {pred}")
raise RuntimeError(f"Unsupprted atomic predicate expression {pred}")
gp_cons = self._gp_model.addConstr(cons)
self._prev_candidate_constr.append(gp_cons)
def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
dnf = sympy.logic.boolalg.to_dnf(candidate, simplify=True)
if dnf is sympy.false:
return z3.unsat
# else:
if isinstance(dnf, sympy.Or):
conjunct_list = dnf.args
elif dnf is sympy.true or isinstance(dnf, sympy.Not) or isinstance(dnf, sympy.And) \
or isinstance(dnf, sympy.core.relational.Relational):
conjunct_list = [dnf]
else:
raise RuntimeError(f"Candiate formula {dnf} should have been converted to DNF.")
def _candidate_to_conjucts(self, candidate: z3.BoolRef):
init_path = [candidate]
stack = [init_path]
while stack:
curr_path = stack.pop() # remove this path
curr_node = curr_path.pop() # remove last node in this path
if z3.is_false(curr_node):
# the leaf node in this path is false. Skip
continue
elif z3.is_true(curr_node):
if not curr_path:
yield z3.BoolVal(True)
else:
yield z3.And(*curr_path)
elif z3.is_gt(curr_node) or z3.is_ge(curr_node) \
or z3.is_lt(curr_node) or z3.is_le(curr_node):
yield z3.And(*curr_path, curr_node)
elif z3.is_app_of(curr_node, z3.Z3_OP_ITE):
cond, left, right = curr_node.children()
l_path = curr_path.copy()
l_path.extend([cond, left])
r_path = curr_path.copy()
assert len(cond.children()) == 2
lhs, rhs = cond.children()
if z3.is_le(cond):
not_cond = lhs > rhs
elif z3.is_ge(cond):
not_cond = lhs < rhs
else:
raise RuntimeError(f"Unexpected condition {cond} for ITE")
r_path.extend([not_cond, right])
stack.append(r_path)
stack.append(l_path)
else:
raise RuntimeError(f"Candidate formula {curr_node} should have been converted to DNF.")
def check(self, candidate: z3.BoolRef) -> z3.CheckSatResult:
self._cexs.clear()
for conjunct in conjunct_list:
conjunct_iter = self._candidate_to_conjucts(candidate)
print("Checking candidate", flush=True)
for conjunct in conjunct_iter:
# print(".", end='', flush=True)
self._set_candidate(conjunct)
self._gp_model.optimize()
if self._gp_model.status == gp.GRB.INF_OR_UNBD:
......@@ -94,12 +128,23 @@ class DTreeGEMStanleyGurobiTeacher(GEMStanleyGurobiTeacher):
self._gp_model.optimize()
if self._gp_model.status in [gp.GRB.OPTIMAL, gp.GRB.SUBOPTIMAL]:
cex = tuple(self._old_state.x) + tuple(self._percept.x)
self._cexs.append(cex)
cex_list = []
for i in range(self._gp_model.SolCount):
self._gp_model.Params.SolutionNumber = i
cex = tuple(self._old_state.Xn) + tuple(self._percept.Xn)
cex_list.append(cex)
filtered_cex_list = [cex for cex in cex_list
if not self.is_spurious_example(self.state_dim, self.perc_dim, conjunct, cex)]
if filtered_cex_list:
self._cexs.extend(filtered_cex_list)
else:
raise RuntimeError(f"Only found spurious cexs {cex_list} for the conjuct {conjunct}.")
elif self._gp_model.status == gp.GRB.INFEASIBLE:
continue
else:
return z3.unknown
print("Done")
if len(self._cexs) > 0:
return z3.sat
else:
......@@ -107,18 +152,10 @@ class DTreeGEMStanleyGurobiTeacher(GEMStanleyGurobiTeacher):
def test_dtree_gem_stanley_gurobi_teacher():
a_mat = np.array([[0., -1., 0.],
[0., 0., -1.]])
b_vec = np.zeros(2)
coeff_mat = np.array([
[-1, -1],
[1, -1],
[-1, 0],
[0, 1],
])
cut_vec = np.array([1, 2, 3, 4])
candidate = (a_mat, b_vec, coeff_mat, cut_vec)
candidate = sympy.sympify("ITE(x_0 >= 0.2, x_1 <= 1, 0.5*x_0 + x_1 > 3)")
x_0, x_1 = z3.Reals("x_0 x_1")
candidate = z3.If(x_0 >= 0.2,
z3.If(x_1 <= 1, z3.BoolVal(True), 0.5*x_0 + x_1 <= 3),
0.5*x_0 + x_1 > 3)
teacher = DTreeGEMStanleyGurobiTeacher(norm_ord=2)
print(teacher.check(candidate))
......
from typing import Optional, Tuple
import dreal
import gurobipy as gp
from gurobipy import GRB
import numpy as np
import sympy
import z3
from teacher_base import GurobiTeacherBase, DRealTeacherBase, SymPyTeacherBase
......@@ -34,6 +35,12 @@ PERC_GT = np.array([[0., -1., 0.],
[0., 0., -1.]], float)
def z3_float64_const_to_real(v: float) -> z3.RatNumRef:
return z3.simplify(
z3.fpToReal(z3.FPVal(v, z3.Float64()))
)
class GEMStanleyDRealTeacher(DRealTeacherBase):
def __init__(self, name="gem_stanley", norm_ord=2, delta=0.001) -> None:
......@@ -83,13 +90,32 @@ class GEMStanleyDRealTeacher(DRealTeacherBase):
class GEMStanleyGurobiTeacher(GurobiTeacherBase):
@staticmethod
def is_spurious_example(state_dim: int, perc_dim: int, candidate: z3.BoolRef, cex: Tuple[float, ...]) -> bool:
state_subs_map = [(z3.Real(f"x_{i}"), z3_float64_const_to_real(cex[i])) for i in range(state_dim)]
perc_subs_map = [(z3.Real(f"z_{i}"), z3_float64_const_to_real(cex[i+state_dim])) for i in range(perc_dim)]
sub_map = state_subs_map + perc_subs_map
val = z3.simplify(z3.substitute(candidate, *sub_map))
assert z3.is_bool(val)
if z3.is_false(val):
return True
elif z3.is_true(val):
return False
else:
raise RuntimeError(f"Cannot validate negative example {cex} by substitution")
def __init__(self, name="gem_stanley", norm_ord=2) -> None:
super().__init__(name=name,
state_dim=3, perc_dim=2, ctrl_dim=1, norm_ord=norm_ord)
def is_positive_example(self, ex) -> bool:
def is_positive_example(self, ex) -> Optional[bool]:
assert len(ex) == self.state_dim + self.perc_dim
state_arr = np.asfarray(ex[0:self.state_dim])
if not (np.all(self._old_state.lb <= state_arr) and
np.all(state_arr <= self._old_state.ub)):
return False
def g(cte, phi):
error = phi + np.arctan(K_P*cte/FORWARD_VEL)
steer = np.clip(error, -STEERING_LIM, STEERING_LIM)
......@@ -159,6 +185,7 @@ class GEMStanleyGurobiTeacher(GurobiTeacherBase):
m.addConstr(new_y == old_y + FORWARD_VEL*CYCLE_TIME*sin_yaw_steer)
m.addConstr(new_yaw ==
old_yaw + sin_steer*FORWARD_VEL*CYCLE_TIME/WHEEL_BASE)
m.update()
def _add_unsafe(self) -> None:
assert PERC_GT.shape == (self.perc_dim, self.state_dim)
......@@ -179,6 +206,7 @@ class GEMStanleyGurobiTeacher(GurobiTeacherBase):
m.addConstr(new_lya_val == gp.norm(new_truth, float(self._norm_ord)))
m.addConstr(new_lya_val >= old_lya_val, name="Non-decreasing Error") # Tracking error is non-decreasing
m.update()
def _add_objective(self) -> None:
# Variable Aliases
......@@ -193,6 +221,7 @@ class GEMStanleyGurobiTeacher(GurobiTeacherBase):
m.addConstr(norm_var == gp.norm(z_diff, float(self._norm_ord)))
m.setObjective(norm_var, gp.GRB.MINIMIZE)
m.update()
class GEMStanleySymPyTeacher(SymPyTeacherBase):
......
import abc
import itertools
from typing import Dict, Hashable, List, Literal, Sequence, Tuple
from typing import Dict, Hashable, List, Literal, Optional, Sequence, Tuple
import gurobipy as gp
import numpy as np
......@@ -14,11 +14,11 @@ class TeacherBase(abc.ABC):
pass
@abc.abstractmethod
def is_positive_example(self, ex) -> bool:
def is_positive_example(self, ex) -> Optional[bool]:
raise NotImplementedError
@abc.abstractmethod
def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
def check(self, candidate) -> z3.CheckSatResult:
raise NotImplementedError
@abc.abstractmethod
......@@ -90,21 +90,8 @@ class DRealTeacherBase(TeacherBase):
def ctrl_dim(self) -> int:
return len(self._control)
def _gen_cand_pred(self, candidate: sympy.logic.boolalg.Boolean):
raise NotImplementedError("TODO convert from sympy")
a_mat, b_vec, coeff_mat, cut_vec = candidate # TODO parse candidate
# VARIABLE ALIASES
x = self._old_state
z = self._percept
trans_vars = [z_i - z_hat_i for z_i, z_hat_i in zip(z, self.affine_trans_exprs(x, a_mat, b_vec))]
print("Transformed vars:", trans_vars)
polytope_exprs = self.affine_trans_exprs(trans_vars, coeff_mat)
return dreal.And(
*(poly_expr <= cut for poly_expr, cut in zip(polytope_exprs, cut_vec))
)
def _gen_cand_pred(self, candidate):
raise NotImplementedError(f"TODO convert {candidate} to dReal formula")
def _set_var_bound(self, smt_vars: List[dreal.Variable], lb: Sequence[float], ub: Sequence[float]) -> None:
assert len(lb) == len(smt_vars)
......@@ -115,7 +102,7 @@ class DRealTeacherBase(TeacherBase):
def set_old_state_bound(self, lb: Sequence[float], ub: Sequence[float]) -> None:
self._set_var_bound(self._old_state, lb, ub)
def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
def check(self, candidate) -> z3.CheckSatResult:
bound_preds = [
dreal.And(lb_i <= var_i, var_i <= ub_i)
for var_i, (lb_i, ub_i) in self._var_bounds.items()
......@@ -213,8 +200,9 @@ class GurobiTeacherBase(TeacherBase):
def set_old_state_bound(self, lb, ub) -> None:
self._old_state.lb = lb
self._old_state.ub = ub
self._gp_model.update()
def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
def check(self, candidate) -> z3.CheckSatResult:
raise NotImplementedError
def dump_system_encoding(self, basename: str = "") -> None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment