Skip to content
Snippets Groups Projects
Commit b4a41659 authored by chsieh16's avatar chsieh16
Browse files

Add threhold in candidates and assertion to check negative examples

parent ddbe0a57
No related branches found
No related tags found
No related merge requests found
......@@ -32,11 +32,11 @@ def test_synth_dtree():
positive_examples = load_positive_examples(
"data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle")
#positive_examples = positive_examples[:20:] # Select only first few examples
# positive_examples = positive_examples[:20:] # Select only first few examples
ex_dim = len(positive_examples[0])
print("#examples: %d" % len(positive_examples))
print("Dimension of each example: %d" % ex_dim)
assert all(len(ex) == ex_dim and not any(np.isnan(ex))
for ex in positive_examples)
......@@ -56,11 +56,11 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
[0., 0., -1.]])
b_vec_0 = np.zeros(2)
learner.set_grammar([(a_mat_0, b_vec_0)])
#writing positive examples in csv file
# writing positive examples in csv file
f = open("positive_sample.csv", 'w')
data_out = csv.writer(f)
for ps in positive_examples:
data_out.writerow(itertools.chain(ps,["true"]))
data_out.writerow(itertools.chain(ps, ["true"]))
f.close()
learner.add_positive_examples(*positive_examples)
......@@ -78,7 +78,7 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
return
print(f"candidate DNF: {candidate_dnf}")
past_candidate_list.append(candidate_dnf)
# QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES
print(f"number of paths: {len(candidate_dnf)}")
......@@ -89,6 +89,7 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
if result == z3.sat:
m = teacher.model()
assert len(m) > 0
assert validate_negative_examples(candidate, m)
negative_examples.extend(m)
# TODO check if negative example state is spurious or true courterexample
elif result == z3.unsat:
......@@ -100,9 +101,9 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
fneg = open("negative_sample.csv", 'w')
data_out = csv.writer(fneg)
for pn in negative_examples:
data_out.writerow(itertools.chain(pn,["true"]))
data_out.writerow(itertools.chain(pn, ["true"]))
fneg.close()
print(f"negative examples: {negative_examples}")
if len(negative_examples) > 0:
learner.add_negative_examples(*negative_examples)
......@@ -116,5 +117,16 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
return []
def validate_negative_examples(candidate, neg_exs) -> bool:
a_mat, b_vec, coeff_mat, cut_vec = candidate
for ex in neg_exs:
z_diff = ex[3:5] - (a_mat @ ex[0:3] + b_vec)
if not np.all(coeff_mat @ z_diff <= cut_vec):
print(ex)
return False
return True
if __name__ == "__main__":
test_synth_dtree()
......@@ -28,7 +28,7 @@ class DTreeTeacherGEMStanley(GEMStanleyTeacher):
self._prev_candidate_constr.append(cons)
# Constraints on values of z
cons = m.addMConstr(coeff_mat, z_diff, '<', cut_vec)
cons = m.addMConstr(coeff_mat, z_diff, '<', cut_vec - 10**-3)
self._prev_candidate_constr.append(cons)
# L2-norm objective
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment