import csv
import itertools
import pickle
from typing import List, Tuple

import numpy as np
import z3

from dtree_learner import DTreeLearner as Learner
from dtree_teacher_gem_stanley import DTreeTeacherGEMStanley as Teacher


def load_positive_examples(file_name: str) -> List[Tuple[float, ...]]:
    with open(file_name, "rb") as pickle_file_io:
        pkl_data = pickle.load(pickle_file_io)
    truth_samples_seq = pkl_data["truth_samples"]

    i_th = 0  # select only the i-th partition
    truth_samples_seq = truth_samples_seq[i_th:i_th + 1]
    print("Representative point in partition:", truth_samples_seq[0][0])
    # Drop any sample containing NaN values
    truth_samples_seq = [
        (t, [s for s in raw_samples if not any(np.isnan(s))])
        for t, raw_samples in truth_samples_seq
    ]
    # Convert from sampled states and percepts to positive examples for learning
    return [s for _, samples in truth_samples_seq for s in samples]


def test_synth_dtree():
    positive_examples = load_positive_examples(
        "data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle")
    # positive_examples = positive_examples[:20:]  # Select only the first few examples

    ex_dim = len(positive_examples[0])
    print("#examples: %d" % len(positive_examples))
    print("Dimension of each example: %d" % ex_dim)
    assert all(len(ex) == ex_dim and not any(np.isnan(ex))
               for ex in positive_examples)

    teacher = Teacher()
    assert teacher.state_dim + teacher.perc_dim == ex_dim
    # 0.0 <= x <= 30.0 and -1.0 <= y <= -0.9 and 0.2 <= theta <= 0.22
    teacher.set_old_state_bound(lb=[0.0, -1.0, 0.2], ub=[30.0, -0.9, 0.22])
    synth_dtree(positive_examples, teacher, num_max_iterations=50)


def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
    learner = Learner(state_dim=teacher.state_dim,
                      perc_dim=teacher.perc_dim, timeout=20000)
    # Grammar: a single affine template (A, b); A maps the state (x, y, theta) to (-y, -theta)
    a_mat_0 = np.array([[0., -1., 0.],
                        [0., 0., -1.]])
    b_vec_0 = np.zeros(2)
    learner.set_grammar([(a_mat_0, b_vec_0)])

    # Write positive examples to a CSV file for later inspection
    with open("positive_sample.csv", 'w', newline='') as f:
        data_out = csv.writer(f)
        for ps in positive_examples:
            data_out.writerow(itertools.chain(ps, ["true"]))

    learner.add_positive_examples(*positive_examples)

    # CEGIS-style loop: the learner proposes a candidate DNF and the teacher
    # searches for counterexamples (negative examples) until none remain.
    past_candidate_list = []
    for k in range(num_max_iterations):
        print(f"Iteration {k}:")
        print("learning ....")
        candidate_dnf = learner.learn()
        print("done learning")
        if not candidate_dnf:  # learning failed
            print("Learning Failed.")
            return
        print(f"candidate DNF: {candidate_dnf}")
        past_candidate_list.append(candidate_dnf)

        # Query the teacher for negative examples along each path of the candidate
        print(f"number of paths: {len(candidate_dnf)}")
        negative_examples = []
        for candidate in candidate_dnf:
            result = teacher.check(candidate)
            print(result)
            if result == z3.sat:
                m = teacher.model()
                assert len(m) > 0
                negative_examples.extend(m)
                # TODO check if negative example state is spurious or a true counterexample
            elif result == z3.unsat:
                continue
            else:
                print("Reason Unknown", teacher.reason_unknown())
                return past_candidate_list

        # Write negative examples to a CSV file, labeled "false"
        with open("negative_sample.csv", 'w', newline='') as fneg:
            data_out = csv.writer(fneg)
            for pn in negative_examples:
                data_out.writerow(itertools.chain(pn, ["false"]))

        print(f"negative examples: {negative_examples}")
        if len(negative_examples) > 0:
            learner.add_negative_examples(*negative_examples)
        else:
            # No counterexamples found: save the winning DNF and stop
            with open("winnerDnf", 'wb') as f_dnf:
                pickle.dump(candidate_dnf, f_dnf)
            print("we are done!")
            return past_candidate_list

    print("Reached max iteration %d." % num_max_iterations)
    return []


if __name__ == "__main__":
    test_synth_dtree()
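

# --- Illustrative helper (assumed usage; not called by the pipeline above) ---
# A minimal sketch of how the "winnerDnf" artifact written by synth_dtree could
# be loaded back for later inspection. It only assumes the file was produced by
# the pickle.dump call above; the function name is chosen here for illustration.
def load_winner_dnf(file_name: str = "winnerDnf"):
    """Load the candidate DNF saved when the teacher finds no negative examples."""
    with open(file_name, "rb") as f_dnf:
        return pickle.load(f_dnf)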