# NOTE(review): removed stray "Newer" / "Older" lines — leftover diff/merge
# artifact; they were not valid Python and broke the module at import time.
import csv
import itertools
import pickle
from typing import List, Tuple

import numpy as np
import z3

from dtree_learner import DTreeLearner as Learner
from dtree_teacher_gem_stanley import DTreeTeacherGEMStanley as Teacher
def load_positive_examples(file_name: str, i_th: int = 0) -> List[Tuple[float, ...]]:
    """Load positive examples for learning from a pickled sample file.

    The pickle is expected to hold a dict with key ``"truth_samples"`` mapping
    to a sequence of ``(t, raw_samples)`` pairs (one per partition).

    Args:
        file_name: Path to the pickle file.
        i_th: Index of the single partition to keep (was hard-coded to 0;
            now a backward-compatible parameter defaulting to 0).

    Returns:
        A flat list of samples from the selected partition, with any sample
        containing a NaN component dropped.
    """
    # NOTE(security): pickle.load executes arbitrary code — only load trusted files.
    with open(file_name, "rb") as pickle_file_io:
        pkl_data = pickle.load(pickle_file_io)
    truth_samples_seq = pkl_data["truth_samples"]

    # Select only the i-th partition.
    truth_samples_seq = truth_samples_seq[i_th:i_th + 1]
    print("Representative point in partition:", truth_samples_seq[0][0])

    # Filter out samples containing NaNs (e.g. failed perception readings).
    truth_samples_seq = [(t, [s for s in raw_samples if not any(np.isnan(s))])
                         for t, raw_samples in truth_samples_seq]
    # Convert from sampled states and percepts to positive examples for learning
    return [
        s for _, samples in truth_samples_seq for s in samples
    ]
def test_synth_dtree():
    """End-to-end driver: load filtered samples, sanity-check them, and run
    decision-tree DNF synthesis for up to 50 CEGIS iterations."""
    pkl_path = "data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle"
    positive_examples = load_positive_examples(pkl_path)
    # positive_examples = positive_examples[:20:]  # Select only first few examples

    ex_dim = len(positive_examples[0])
    print("#examples: %d" % len(positive_examples))
    print("Dimension of each example: %d" % ex_dim)
    # Every example must be NaN-free and share the same dimension.
    assert all(len(ex) == ex_dim and not any(np.isnan(ex))
               for ex in positive_examples)

    teacher = Teacher()
    assert teacher.state_dim + teacher.perc_dim == ex_dim

    # Old-state box: 0.0 <= x <= 30.0, -1.0 <= y <= -0.9, 0.2 <= theta <= 0.22
    # (the original comment said "y <= 0.9"; the code's upper bound is -0.9).
    teacher.set_old_state_bound(lb=[0.0, -1.0, 0.2], ub=[30.0, -0.9, 0.22])

    synth_dtree(positive_examples, teacher, num_max_iterations=50)
def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
    """CEGIS-style synthesis loop: learn a candidate DNF from examples, ask the
    teacher for counterexamples, and iterate until no counterexample remains.

    Args:
        positive_examples: Iterable of flat numeric examples
            (state_dim + perc_dim components each).
        teacher: Object exposing ``state_dim``, ``perc_dim``, ``check``,
            ``model`` and ``reason_unknown`` (z3-backed).
        num_max_iterations: Upper bound on CEGIS iterations.

    Returns:
        The list of candidate DNFs tried so far when the loop terminates
        (success or "unknown" verdict), or ``None`` if learning fails.
    """
    learner = Learner(state_dim=teacher.state_dim,
                      perc_dim=teacher.perc_dim, timeout=20000)
    # Fixed affine grammar: residual z = percept - (A @ state + b).
    a_mat_0 = np.array([[0., -1., 0.],
                        [0., 0., -1.]])
    b_vec_0 = np.zeros(2)
    learner.set_grammar([(a_mat_0, b_vec_0)])

    # Persist positive examples for offline inspection.
    # FIX: use a context manager — the original leaked the file handle.
    with open("positive_sample.csv", 'w') as f:
        data_out = csv.writer(f)
        for ps in positive_examples:
            data_out.writerow(itertools.chain(ps, ["true"]))

    learner.add_positive_examples(*positive_examples)

    past_candidate_list = []
    for k in range(num_max_iterations):
        print(f"Iteration {k}:", sep='')
        print("learning ....")
        candidate_dnf = learner.learn()
        print("done learning")
        if not candidate_dnf:  # learning FAILED
            print("Learning Failed.")
            return
        print(f"candidate DNF: {candidate_dnf}")
        past_candidate_list.append(candidate_dnf)

        # QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES
        negative_examples = []
        for candidate in candidate_dnf:
            result = teacher.check(candidate)
            print(result)
            if result == z3.sat:
                m = teacher.model()
                assert len(m) > 0
                assert validate_negative_examples(candidate, m)
                negative_examples.extend(m)
                # TODO check if negative example state is spurious or a true counterexample
            elif result == z3.unsat:
                continue
            else:
                print("Reason Unknown", teacher.reason_unknown())
                return past_candidate_list

        # FIX: context manager (original leaked the handle) and label "false" —
        # the original wrote "true", a copy-paste of the positives writer above.
        # NOTE: opened with 'w' each iteration, so only the latest iteration's
        # negatives persist — preserved from the original behavior.
        with open("negative_sample.csv", 'w') as fneg:
            neg_out = csv.writer(fneg)
            for pn in negative_examples:
                neg_out.writerow(itertools.chain(pn, ["false"]))

        print(f"negative examples: {negative_examples}")
        if len(negative_examples) > 0:
            learner.add_negative_examples(*negative_examples)
        else:
            # No counterexamples: the candidate is accepted. Persist it.
            with open('winnerDnf', 'wb') as winner_file:
                pickle.dump(candidate_dnf, winner_file)
            print("we are done!")
            return past_candidate_list

    print("Reached max iteration %d." % num_max_iterations)
def validate_negative_examples(candidate, neg_exs) -> bool:
    """Return True iff every example in *neg_exs* satisfies the candidate's
    linear constraints.

    ``candidate`` unpacks as ``(a_mat, b_vec, coeff_mat, cut_vec)``. Each
    example is a flat vector whose first 3 components are the state and whose
    next 2 are the percept. An example passes when
    ``coeff_mat @ z <= cut_vec`` holds componentwise for the residual
    ``z = percept - (a_mat @ state + b_vec)``; the first failing example is
    printed and the function returns False.
    """
    a_mat, b_vec, coeff_mat, cut_vec = candidate
    for example in neg_exs:
        state = example[0:3]
        percept = example[3:5]
        residual = percept - (a_mat @ state + b_vec)
        within_bounds = (coeff_mat @ residual <= cut_vec)
        if not within_bounds.all():
            print(example)
            return False
    return True
# Script entry point: run the end-to-end synthesis experiment.
if __name__ == "__main__":
    test_synth_dtree()