Skip to content
Snippets Groups Projects
Commit e5186b45 authored by aastorg2's avatar aastorg2
Browse files

Merge branch 'main' of gitlab.engr.illinois.edu:chsieh16/cs598mp-fall2021-proj into main

parents 164fc16b b4a41659
No related branches found
No related tags found
No related merge requests found
......@@ -33,11 +33,11 @@ def test_synth_dtree():
positive_examples = load_positive_examples(
"data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle")
#positive_examples = positive_examples[:20:] # Select only first few examples
# positive_examples = positive_examples[:20:] # Select only first few examples
ex_dim = len(positive_examples[0])
print("#examples: %d" % len(positive_examples))
print("Dimension of each example: %d" % ex_dim)
assert all(len(ex) == ex_dim and not any(np.isnan(ex))
for ex in positive_examples)
......@@ -57,11 +57,11 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
[0., 0., -1.]])
b_vec_0 = np.zeros(2)
learner.set_grammar([(a_mat_0, b_vec_0)])
#writing positive examples in csv file
# writing positive examples in csv file
f = open("positive_sample.csv", 'w')
data_out = csv.writer(f)
for ps in positive_examples:
data_out.writerow(itertools.chain(ps,["true"]))
data_out.writerow(itertools.chain(ps, ["true"]))
f.close()
learner.add_positive_examples(*positive_examples)
......@@ -79,7 +79,7 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
return
print(f"candidate DNF: {candidate_dnf}")
past_candidate_list.append(candidate_dnf)
# QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES
print(f"number of paths: {len(candidate_dnf)}")
......@@ -91,6 +91,7 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
if result == z3.sat:
m = teacher.model()
assert len(m) > 0
assert validate_negative_examples(candidate, m)
negative_examples.extend(m)
# TODO check if negative example state is spurious or true courterexample
elif result == z3.unsat:
......@@ -102,9 +103,9 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
fneg = open("negative_sample.csv", 'w')
data_out = csv.writer(fneg)
for pn in negative_examples:
data_out.writerow(itertools.chain(pn,["true"]))
data_out.writerow(itertools.chain(pn, ["true"]))
fneg.close()
print(f"negative examples: {negative_examples}")
if len(negative_examples) > 0:
learner.add_negative_examples(*negative_examples)
......@@ -121,5 +122,16 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
return []
def validate_negative_examples(candidate, neg_exs) -> bool:
a_mat, b_vec, coeff_mat, cut_vec = candidate
for ex in neg_exs:
z_diff = ex[3:5] - (a_mat @ ex[0:3] + b_vec)
if not np.all(coeff_mat @ z_diff <= cut_vec):
print(ex)
return False
return True
if __name__ == "__main__":
test_synth_dtree()
......@@ -28,7 +28,7 @@ class DTreeTeacherGEMStanley(GEMStanleyTeacher):
self._prev_candidate_constr.append(cons)
# Constraints on values of z
cons = m.addMConstr(coeff_mat, z_diff, '<', cut_vec)
cons = m.addMConstr(coeff_mat, z_diff, '<', cut_vec - 10**-3)
self._prev_candidate_constr.append(cons)
# L2-norm objective
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment