diff --git a/dtree_synth.py b/dtree_synth.py index f6066a1be677020ce8bb9a216f83b23479a8e893..9b0e344b7935fed66dfab414d15b96c872132a9c 100644 --- a/dtree_synth.py +++ b/dtree_synth.py @@ -33,11 +33,11 @@ def test_synth_dtree(): positive_examples = load_positive_examples( "data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle") - #positive_examples = positive_examples[:20:] # Select only first few examples + # positive_examples = positive_examples[:20:] # Select only first few examples ex_dim = len(positive_examples[0]) print("#examples: %d" % len(positive_examples)) print("Dimension of each example: %d" % ex_dim) - + assert all(len(ex) == ex_dim and not any(np.isnan(ex)) for ex in positive_examples) @@ -57,11 +57,11 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10): [0., 0., -1.]]) b_vec_0 = np.zeros(2) learner.set_grammar([(a_mat_0, b_vec_0)]) - #writing positive examples in csv file + # writing positive examples in csv file f = open("positive_sample.csv", 'w') data_out = csv.writer(f) for ps in positive_examples: - data_out.writerow(itertools.chain(ps,["true"])) + data_out.writerow(itertools.chain(ps, ["true"])) f.close() learner.add_positive_examples(*positive_examples) @@ -79,7 +79,7 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10): return print(f"candidate DNF: {candidate_dnf}") - + past_candidate_list.append(candidate_dnf) # QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES print(f"number of paths: {len(candidate_dnf)}") @@ -91,6 +91,7 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10): if result == z3.sat: m = teacher.model() assert len(m) > 0 + assert validate_negative_examples(candidate, m) negative_examples.extend(m) # TODO check if negative example state is spurious or true courterexample elif result == z3.unsat: @@ -102,9 +103,9 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10): fneg = open("negative_sample.csv", 'w') data_out = csv.writer(fneg) for pn in negative_examples: - data_out.writerow(itertools.chain(pn,["true"])) + data_out.writerow(itertools.chain(pn, ["true"])) fneg.close() - + print(f"negative examples: {negative_examples}") if len(negative_examples) > 0: learner.add_negative_examples(*negative_examples) @@ -121,5 +122,16 @@ def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10): return [] +def validate_negative_examples(candidate, neg_exs) -> bool: + a_mat, b_vec, coeff_mat, cut_vec = candidate + + for ex in neg_exs: + z_diff = ex[3:5] - (a_mat @ ex[0:3] + b_vec) + if not np.all(coeff_mat @ z_diff <= cut_vec): + print(ex) + return False + return True + + if __name__ == "__main__": test_synth_dtree() diff --git a/dtree_teacher_gem_stanley.py b/dtree_teacher_gem_stanley.py index 0cd904d10087bea5eba6e88a20e2e81a53f7963a..af00ec585e659259ea39d100fe16833438a3793b 100644 --- a/dtree_teacher_gem_stanley.py +++ b/dtree_teacher_gem_stanley.py @@ -28,7 +28,7 @@ class DTreeTeacherGEMStanley(GEMStanleyTeacher): self._prev_candidate_constr.append(cons) # Constraints on values of z - cons = m.addMConstr(coeff_mat, z_diff, '<', cut_vec) + cons = m.addMConstr(coeff_mat, z_diff, '<', cut_vec - 10**-3) self._prev_candidate_constr.append(cons) # L2-norm objective