Skip to content
Snippets Groups Projects
Commit b4a41659 authored by chsieh16's avatar chsieh16
Browse files

Add threhold in candidates and assertion to check negative examples

parent ddbe0a57
No related branches found
No related tags found
No related merge requests found
...@@ -32,11 +32,11 @@ def test_synth_dtree(): ...@@ -32,11 +32,11 @@ def test_synth_dtree():
positive_examples = load_positive_examples( positive_examples = load_positive_examples(
"data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle") "data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle")
#positive_examples = positive_examples[:20:] # Select only first few examples # positive_examples = positive_examples[:20:] # Select only first few examples
ex_dim = len(positive_examples[0]) ex_dim = len(positive_examples[0])
print("#examples: %d" % len(positive_examples)) print("#examples: %d" % len(positive_examples))
print("Dimension of each example: %d" % ex_dim) print("Dimension of each example: %d" % ex_dim)
assert all(len(ex) == ex_dim and not any(np.isnan(ex)) assert all(len(ex) == ex_dim and not any(np.isnan(ex))
for ex in positive_examples) for ex in positive_examples)
...@@ -56,11 +56,11 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10): ...@@ -56,11 +56,11 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
[0., 0., -1.]]) [0., 0., -1.]])
b_vec_0 = np.zeros(2) b_vec_0 = np.zeros(2)
learner.set_grammar([(a_mat_0, b_vec_0)]) learner.set_grammar([(a_mat_0, b_vec_0)])
#writing positive examples in csv file # writing positive examples in csv file
f = open("positive_sample.csv", 'w') f = open("positive_sample.csv", 'w')
data_out = csv.writer(f) data_out = csv.writer(f)
for ps in positive_examples: for ps in positive_examples:
data_out.writerow(itertools.chain(ps,["true"])) data_out.writerow(itertools.chain(ps, ["true"]))
f.close() f.close()
learner.add_positive_examples(*positive_examples) learner.add_positive_examples(*positive_examples)
...@@ -78,7 +78,7 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10): ...@@ -78,7 +78,7 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
return return
print(f"candidate DNF: {candidate_dnf}") print(f"candidate DNF: {candidate_dnf}")
past_candidate_list.append(candidate_dnf) past_candidate_list.append(candidate_dnf)
# QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES # QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES
print(f"number of paths: {len(candidate_dnf)}") print(f"number of paths: {len(candidate_dnf)}")
...@@ -89,6 +89,7 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10): ...@@ -89,6 +89,7 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
if result == z3.sat: if result == z3.sat:
m = teacher.model() m = teacher.model()
assert len(m) > 0 assert len(m) > 0
assert validate_negative_examples(candidate, m)
negative_examples.extend(m) negative_examples.extend(m)
# TODO check if negative example state is spurious or true courterexample # TODO check if negative example state is spurious or true courterexample
elif result == z3.unsat: elif result == z3.unsat:
...@@ -100,9 +101,9 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10): ...@@ -100,9 +101,9 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
fneg = open("negative_sample.csv", 'w') fneg = open("negative_sample.csv", 'w')
data_out = csv.writer(fneg) data_out = csv.writer(fneg)
for pn in negative_examples: for pn in negative_examples:
data_out.writerow(itertools.chain(pn,["true"])) data_out.writerow(itertools.chain(pn, ["true"]))
fneg.close() fneg.close()
print(f"negative examples: {negative_examples}") print(f"negative examples: {negative_examples}")
if len(negative_examples) > 0: if len(negative_examples) > 0:
learner.add_negative_examples(*negative_examples) learner.add_negative_examples(*negative_examples)
...@@ -116,5 +117,16 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10): ...@@ -116,5 +117,16 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
return [] return []
def validate_negative_examples(candidate, neg_exs) -> bool:
a_mat, b_vec, coeff_mat, cut_vec = candidate
for ex in neg_exs:
z_diff = ex[3:5] - (a_mat @ ex[0:3] + b_vec)
if not np.all(coeff_mat @ z_diff <= cut_vec):
print(ex)
return False
return True
if __name__ == "__main__": if __name__ == "__main__":
test_synth_dtree() test_synth_dtree()
...@@ -28,7 +28,7 @@ class DTreeTeacherGEMStanley(GEMStanleyTeacher): ...@@ -28,7 +28,7 @@ class DTreeTeacherGEMStanley(GEMStanleyTeacher):
self._prev_candidate_constr.append(cons) self._prev_candidate_constr.append(cons)
# Constraints on values of z # Constraints on values of z
cons = m.addMConstr(coeff_mat, z_diff, '<', cut_vec) cons = m.addMConstr(coeff_mat, z_diff, '<', cut_vec - 10**-3)
self._prev_candidate_constr.append(cons) self._prev_candidate_constr.append(cons)
# L2-norm objective # L2-norm objective
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment