Newer
Older
import pickle
from typing import List, Tuple

import numpy as np
import sympy
import z3

from dtree_learner import DTreeLearner as Learner
from dtree_teacher_gem_stanley import DTreeGEMStanleyGurobiTeacher as Teacher
def load_positive_examples(file_name: str,
                           i_th: int = 0) -> List[Tuple[float, ...]]:
    """Load positive learning examples from one partition of a pickle file.

    The pickle must hold a dict whose ``"truth_samples"`` entry is a sequence
    of ``(representative_point, samples)`` pairs, one pair per partition.
    Only partition ``i_th`` is used, and samples containing any NaN are
    dropped.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only load files
    from trusted sources.

    :param file_name: Path to the pickle file.
    :param i_th: Non-negative index of the partition to select (default 0).
    :return: The NaN-free samples of the selected partition, in file order.
    :raises IndexError: If there is no partition at index ``i_th``.
    """
    with open(file_name, "rb") as pickle_file_io:
        pkl_data = pickle.load(pickle_file_io)
    truth_samples_seq = pkl_data["truth_samples"]

    # Slicing never raises, so an out-of-range index yields an empty
    # selection; report that explicitly instead of failing on selected[0].
    selected = truth_samples_seq[i_th:i_th + 1]
    if not selected:
        raise IndexError(f"No partition at index {i_th} in {file_name!r}")

    representative, raw_samples = selected[0]
    print("Representative point in partition:", representative)
    # Convert from sampled states and percepts to positive examples for learning
    return [s for s in raw_samples if not any(np.isnan(s))]
def test_synth_dtree():
    """Run decision-tree synthesis on recorded data from a fixed pickle file.

    Loads positive examples, checks that every example has the dimension the
    teacher expects (state dimension + percept dimension) and contains no
    NaN, bounds the teacher's old state, and runs the synthesis loop for up
    to 2000 iterations.
    """
    positive_examples = load_positive_examples(
        "data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle")
    # positive_examples = positive_examples[:20:]  # Select only first few examples
    # Fail with a clear message rather than an IndexError on [0] below.
    assert positive_examples, "no positive examples were loaded"
    ex_dim = len(positive_examples[0])
    print(f"#examples: {len(positive_examples)}")
    print(f"Dimension of each example: {ex_dim}")
    assert all(len(ex) == ex_dim and not any(np.isnan(ex))
               for ex in positive_examples)

    teacher = Teacher()
    assert teacher.state_dim + teacher.perc_dim == ex_dim
    # 0.0 <= x <= 30.0 and -1.0 <= y <= -0.9 and 0.2 <= theta <= 0.22
    teacher.set_old_state_bound(lb=[0.0, -1.0, 0.2], ub=[30.0, -0.9, 0.22])
    synth_dtree(positive_examples, teacher, num_max_iterations=2000)
def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
    """Run the learner/teacher loop to synthesize a decision-tree candidate.

    In each iteration the learner proposes a candidate, and the teacher checks
    it:
      * ``z3.sat``   -- the teacher's model gives counterexamples; they are
        validated with ``validate_cexs`` and added as negative examples.
      * ``z3.unsat`` -- the loop stops ("we are done!").
      * otherwise    -- the teacher's ``reason_unknown()`` is printed and the
        loop stops.

    :param positive_examples: Examples given to the learner as positive.
    :param teacher: Object providing ``state_dim``, ``perc_dim``,
        ``check(candidate)``, ``model()`` and ``reason_unknown()``.
    :param num_max_iterations: Upper bound on the number of iterations.
    :return: Every candidate proposed, in order. This is returned on all
        exit paths, including when the iteration bound is reached.
    """
    learner = Learner(state_dim=teacher.state_dim,
                      perc_dim=teacher.perc_dim, timeout=20000)

    a_mat_0 = np.array([[0., -1., 0.],
                        [0., 0., -1.]])
    b_vec_0 = np.zeros(2)
    # Let z = [z_0, z_1] = [d, psi]; x = [x_0, x_1, x_2] = [x, y, theta]
    # a_mat_0 @ [x, y, theta] + b_vec_0 = [-y, -theta]
    # z - (a_mat_0 @ x + b_vec_0) = [d, psi] - [-y, -theta] = [d+y, psi+theta] defined as [fvar0_A0, fvar1_A0]
    learner.set_grammar([(a_mat_0, b_vec_0)])
    learner.add_positive_examples(*positive_examples)

    past_candidate_list = []
    for k in range(num_max_iterations):
        print(f"Iteration {k}:", sep='')
        print("learning ....")
        # Ask the learner for the next candidate.
        # NOTE(review): assumes DTreeLearner exposes `learn()` returning the
        # candidate formula (a sympy Boolean, as validate_cexs expects) --
        # confirm against dtree_learner.
        candidate = learner.learn()
        print(f"candidate: {candidate}")
        past_candidate_list.append(candidate)

        # QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES
        result = teacher.check(candidate)
        if result == z3.sat:
            negative_examples = teacher.model()
            assert len(negative_examples) > 0
            print(f"negative examples: {negative_examples}")
            assert validate_cexs(teacher.state_dim, teacher.perc_dim, candidate, negative_examples)
            learner.add_negative_examples(*negative_examples)
        elif result == z3.unsat:
            print("we are done!")
            return past_candidate_list
        else:
            print("Reason Unknown", teacher.reason_unknown())
            return past_candidate_list

    print("Reached max iteration %d." % num_max_iterations)
    # Return the candidates here too, consistent with the other exit paths.
    return past_candidate_list
def validate_cexs(state_dim: int, perc_dim: int,
                  candidate: sympy.logic.boolalg.Boolean,
                  cexs: List[Tuple[float, ...]]) -> bool:
    """Check that every counterexample makes ``candidate`` evaluate to true.

    Each counterexample is a flat tuple of ``state_dim + perc_dim`` values.
    The first ``state_dim`` values replace the symbols named ``x_0``, ...,
    and the remaining ``perc_dim`` values replace ``z_0``, .... A
    counterexample for which ``candidate`` becomes ``sympy.false`` is
    spurious.

    :param state_dim: Number of leading state entries in each counterexample.
    :param perc_dim: Number of trailing percept entries in each counterexample.
    :param candidate: Formula over symbols ``x_i`` and ``z_i``.
    :param cexs: Counterexamples to check.
    :return: True if none is spurious; otherwise prints the spurious
        counterexamples and returns False.
    """
    spurious_cexs = []
    for cex in cexs:
        # Keys are strings; sympy sympifies them to plain Symbols.
        # NOTE(review): symbols created with assumptions (e.g. real=True)
        # would not match these keys -- confirm against the learner.
        sub_map = [(f"x_{i}", cex[i]) for i in range(state_dim)]
        sub_map += [(f"z_{i}", cex[state_dim + i]) for i in range(perc_dim)]
        val = candidate.subs(sub_map)
        # Every free symbol must have been substituted, leaving true/false.
        assert isinstance(val, sympy.logic.boolalg.BooleanAtom)
        if val == sympy.false:
            spurious_cexs.append(cex)

    if spurious_cexs:
        print("Spurious CEXs:", *spurious_cexs, sep='\n')
        return False
    return True
if __name__ == "__main__":
    # Script entry point: run the synthesis experiment on the recorded data.
    test_synth_dtree()