"""Driver that synthesizes a candidate with ``DTreeLearner`` and checks it
against ``DTreeGEMStanleyGurobiTeacher`` in a learn/check loop."""
import pickle
from typing import List, Tuple

import numpy as np
import z3

from dtree_learner import DTreeLearner as Learner
from dtree_teacher_gem_stanley import DTreeGEMStanleyGurobiTeacher as Teacher


def load_examples(file_name: str, spec) -> Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]]]:
    """Load sampled examples from a pickle file and split them using ``spec``.

    The pickle is expected to contain a dict with key ``"truth_samples"``
    holding a sequence of ``(_, samples)`` pairs; the first element of each
    pair is ignored and every element of ``samples`` is classified.

    Args:
        file_name: Path of the pickle file to read.
        spec: Callable applied to each sample. A truthy result makes the sample
            a positive example, a falsy result a negative example, and ``None``
            excludes it.

    Returns:
        ``(pos_exs, neg_exs)``. Samples containing NaN, or for which ``spec``
        returns ``None``, are excluded; only their count is printed.
    """
    print("Loading examples")
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only load trusted data files here.
    with open(file_name, "rb") as pickle_file_io:
        pkl_data = pickle.load(pickle_file_io)
    truth_samples_seq = pkl_data["truth_samples"]

    # Convert from sampled states and percepts to positive and negative examples for learning
    pos_exs, neg_exs, num_excl_exs = [], [], 0
    for _, ss in truth_samples_seq:
        for s in ss:
            # Reject NaN samples before evaluating spec so it never sees them.
            if np.any(np.isnan(s)):
                num_excl_exs += 1
                continue
            ret = spec(s)
            if ret is None:
                num_excl_exs += 1
            elif ret:
                pos_exs.append(s)
            else:
                neg_exs.append(s)
    print("# Excluded examples:", num_excl_exs)
    return pos_exs, neg_exs


def test_synth_dtree():
    """Run synthesis on the recorded dataset with a fixed old-state bound.

    Raises:
        ValueError: If the data file yields no positive examples, since the
            example dimension is inferred from the first positive example.
    """
    teacher = Teacher(norm_ord=1)
    # 0.0 <= x <= 32.0 and -1.0 <= y <= -0.9 and 0.2 <= theta <= 0.22
    teacher.set_old_state_bound(lb=[0.0, -1.0, 0.2], ub=[32.0, -0.9, 0.22])
    # Alternative bounds (disabled):
    # teacher.set_old_state_bound(lb=[0.0, -0.9, 2*np.pi/60], ub=[32.0, -0.6, 3*np.pi/60])
    # teacher.set_old_state_bound(lb=[0.0, 0.3, 1*np.pi/60], ub=[32.0, 0.9, 5*np.pi/60])

    positive_examples, negative_examples = load_examples(
        # "data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle",
        "data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle",
        teacher.is_positive_example)
    print(f"# positive examples: {len(positive_examples)}")
    print(f"# negative examples: {len(negative_examples)}")

    if not positive_examples:
        raise ValueError("No positive examples loaded; cannot determine example dimension")
    ex_dim = len(positive_examples[0])
    print(f"Dimension of each example: {ex_dim}")
    assert all(len(ex) == ex_dim and not any(np.isnan(ex)) for ex in positive_examples)
    # Negative examples feed the same learner, so they must share the dimension.
    assert all(len(ex) == ex_dim for ex in negative_examples)
    assert teacher.state_dim + teacher.perc_dim == ex_dim
    synth_dtree(positive_examples, negative_examples, teacher, num_max_iterations=2000)


def synth_dtree(positive_examples, negative_examples, teacher, num_max_iterations: int = 10):
    """Alternate between the learner and the teacher until a verdict is reached.

    Each iteration the learner proposes a candidate from all examples seen so
    far and the teacher checks it:

    * ``z3.sat``: the teacher's model is added to the learner as new negative
      examples and the loop continues.
    * ``z3.unsat``: the candidate is accepted and synthesis stops.
    * anything else: the teacher's unknown reason is printed and synthesis stops.

    Args:
        positive_examples: Initial positive examples for the learner.
        negative_examples: Initial negative examples for the learner.
        teacher: Object providing ``state_dim``, ``perc_dim``, ``check``,
            ``model`` and ``reason_unknown``.
        num_max_iterations: Maximum number of learn/check rounds.

    Returns:
        All candidates proposed so far when the teacher answers ``unsat`` or
        an unknown result; an empty list if ``num_max_iterations`` rounds pass
        without either.
    """
    # NOTE(review): timeout units are defined by Learner; presumably ms — confirm.
    learner = Learner(state_dim=teacher.state_dim, perc_dim=teacher.perc_dim, timeout=20000)

    a_mat_0 = np.array([[0., -1., 0.],
                        [0., 0., -1.]])
    b_vec_0 = np.zeros(2)
    # Let z = [z_0, z_1] = [d, psi]; x = [x_0, x_1, x_2] = [x, y, theta]
    # a_mat_0 @ [x, y, theta] + b_vec_0 = [-y, -theta]
    # z - (a_mat_0 @ x + b_vec_0) = [d, psi] - [-y, -theta] = [d+y, psi+theta] defined as [fvar0_A0, fvar1_A0]
    learner.set_grammar([(a_mat_0, b_vec_0)])

    learner.add_positive_examples(*positive_examples)
    learner.add_negative_examples(*negative_examples)

    past_candidate_list = []
    for k in range(num_max_iterations):
        print("=" * 80)
        print(f"Iteration {k}:")

        print("learning ....")
        candidate = learner.learn()
        print("done learning")
        print(f"candidate: {candidate}")
        past_candidate_list.append(candidate)

        # QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES
        result = teacher.check(candidate)
        print(f"Satisfiability: {result}")
        if result == z3.sat:
            # Counterexamples from the teacher; kept separate from the
            # negative_examples argument instead of overwriting it.
            cexs = teacher.model()
            print(f"negative examples: {cexs}")
            # Optional sanity check (disabled):
            # assert validate_cexs(teacher.state_dim, teacher.perc_dim, candidate, cexs)
            learner.add_negative_examples(*cexs)
        elif result == z3.unsat:
            print("we are done!")
            print(f"Simplified candidate: {z3.simplify(candidate)}")
            return past_candidate_list
        else:
            print("Reason Unknown", teacher.reason_unknown())
            return past_candidate_list

    print(f"Reached max iteration {num_max_iterations}.")
    return []


def validate_cexs(state_dim: int, perc_dim: int,
                  candidate: z3.BoolRef,
                  cexs: List[Tuple[float, ...]]) -> bool:
    """Return True iff no counterexample in ``cexs`` is spurious for ``candidate``.

    Spuriousness is decided by ``Teacher.is_spurious_example``; any spurious
    counterexamples found are printed one per line.
    """
    spurious_cexs = [cex for cex in cexs
                     if Teacher.is_spurious_example(state_dim, perc_dim, candidate, cex)]
    if spurious_cexs:
        print("Spurious CEXs:", *spurious_cexs, sep='\n')
        return False
    return True


if __name__ == "__main__":
    test_synth_dtree()