#!/usr/bin/env python3
"""Synthesize decision-tree abstractions from sampled (state, percept) data.

The script loads pickled truth samples, labels them as positive/negative
examples with the teacher's specification, and then either plots the learner's
feature vectors per partition cell or runs a learner/teacher refinement loop.
"""
import itertools
import json
import os
import pickle
import traceback
from typing import Any, Dict, Hashable, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import z3

from dtree_learner import DTreeLearner as Learner
from dtree_teacher_gem_stanley import DTreeGEMStanleyGurobiTeacher as Teacher


def _load_truth_samples(file_name: str):
    """Return the ``"truth_samples"`` entry of the pickle file *file_name*.

    NOTE(security): ``pickle.load`` can execute arbitrary code embedded in the
    file, so only point this at pickle files from a trusted source.
    """
    print("Loading examples")
    with open(file_name, "rb") as pickle_file_io:
        pkl_data = pickle.load(pickle_file_io)
    return pkl_data["truth_samples"]


def load_examples(file_name: str, spec) -> Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]]]:
    """Load samples from *file_name* and split them with the predicate *spec*.

    Args:
        file_name: Path to a pickle file holding a ``"truth_samples"`` entry,
            iterated here as pairs whose second item is a sequence of samples.
        spec: Callable applied to each NaN-free sample; a truthy result marks
            the sample as a positive example, otherwise it is negative.

    Returns:
        ``(positive_examples, negative_examples)``. Samples containing any NaN
        are dropped and only counted in the printed summary.
    """
    truth_samples_seq = _load_truth_samples(file_name)

    # Convert from sampled states and percepts to positive and negative examples for learning
    pos_exs, neg_exs, num_excl_exs = [], [], 0
    for _, ss in truth_samples_seq:
        for s in ss:
            if np.any(np.isnan(s)):
                num_excl_exs += 1
            elif spec(s):
                pos_exs.append(s)
            else:
                neg_exs.append(s)
    print("# Exculded NaN examples:", num_excl_exs)
    return pos_exs, neg_exs


def test_synth_dtree():
    """Test using filtered data where we know an abstraction exists."""
    # Initialize Teacher
    teacher = Teacher(norm_ord=1, ultimate_bound=0.25)
    teacher.set_old_state_bound(lb=[0.0, -1.0, 0.2], ub=[32.0, -0.9, 0.22])

    init_positive_examples, init_negative_examples = load_examples(
        "data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle",
        teacher.is_positive_example)

    print("# positive examples: %d" % len(init_positive_examples))
    print("# negative examples: %d" % len(init_negative_examples))

    # NOTE: indexing [0] requires at least one positive example in the data.
    ex_dim = len(init_positive_examples[0])
    print("Dimension of each example: %d" % ex_dim)
    assert all(len(ex) == ex_dim and not any(np.isnan(ex))
               for ex in init_positive_examples)
    assert teacher.state_dim + teacher.perc_dim == ex_dim

    # Initialize Learner
    learner = Learner(state_dim=teacher.state_dim,
                      perc_dim=teacher.perc_dim, timeout=20000)

    a_mat_0 = Teacher.PERC_GT
    b_vec_0 = np.zeros(2)

    # Let z = [z_0, z_1] = [d, psi]; x = [x_0, x_1, x_2] = [x, y, theta]
    # a_mat_0 @ [x, y, theta] + b_vec_0 = [-y, -theta]
    # z - (a_mat_0 @ x + b_vec_0) = [d, psi] - [-y, -theta] = [d+y, psi+theta] defined as [fvar0_A0, fvar1_A0]
    learner.set_grammar([(a_mat_0, b_vec_0)])

    learner.add_positive_examples(*init_positive_examples)
    learner.add_negative_examples(*init_negative_examples)

    synth_dtree(learner, teacher, num_max_iterations=2000)


def synth_dtree(learner: Learner, teacher: Teacher, num_max_iterations: int = 10) -> Tuple[bool, Any]:
    """Run the learner/teacher refinement loop.

    Each iteration asks the learner for a candidate and the teacher to check
    it. A ``z3.sat`` verdict yields new negative examples that are fed back to
    the learner; ``z3.unsat`` ends the loop successfully.

    Args:
        learner: Produces candidates via ``learn()`` and accepts new negative
            examples.
        teacher: Checks candidates via ``check()`` and supplies examples via
            ``model()``.
        num_max_iterations: Upper bound on the number of learn/check rounds.

    Returns:
        ``(True, (k, formula))`` when the teacher reports ``z3.unsat`` at
        iteration ``k``, where ``formula`` is the simplified candidate as an
        s-expression string; otherwise ``(False, reason)`` with a message
        describing an unknown verdict or the exhausted iteration budget.
    """
    for k in range(num_max_iterations):
        print("="*80)
        print(f"Iteration {k}:")
        print("learning ....")
        candidate = learner.learn()
        print("done learning")

        print(f"candidate: {candidate}")

        # QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES
        result = teacher.check(candidate)
        print(f"Satisfiability: {result}")
        if result == z3.sat:
            negative_examples = teacher.model()
            assert len(negative_examples) > 0
            print(f"negative examples: {negative_examples}")

            # assert validate_cexs(teacher.state_dim, teacher.perc_dim, candidate, negative_examples)

            learner.add_negative_examples(*negative_examples)
            continue
        elif result == z3.unsat:
            print("we are done!")
            return True, (k, z3.simplify(candidate, arith_lhs=True).sexpr())
        else:
            return False, f"Reason Unknown {teacher.reason_unknown()}"

    return False, f"Reached max iteration {num_max_iterations}."


def validate_cexs(state_dim: int, perc_dim: int,
                  candidate: z3.BoolRef,
                  cexs: List[Tuple[float, ...]]) -> bool:
    """Return True when none of *cexs* is flagged as spurious for *candidate*.

    Spuriousness is decided by ``Teacher.is_spurious_example``; any spurious
    counterexamples are printed one per line before returning False.
    """
    spurious_cexs = [cex for cex in cexs
                     if Teacher.is_spurious_example(state_dim, perc_dim, candidate, cex)]
    if not spurious_cexs:
        return True
    else:
        print("Spurious CEXs:", *spurious_cexs, sep='\n')
        return False


def search_part(partition, state):
    """Locate the partition cell that contains *state*.

    Args:
        partition: One sorted sequence of cell edges per state dimension.
        state: One coordinate per dimension; must have the same length as
            *partition*.

    Returns:
        A tuple with one ``(lower_edge, upper_edge)`` pair per dimension, or
        ``None`` if any coordinate lies outside its edges. With the default
        left-sided ``np.searchsorted`` each cell is ``(lower, upper]``, so a
        coordinate equal to the first edge is treated as outside.
    """
    assert len(partition) == len(state)
    bounds = []
    for sorted_list, v in zip(partition, state):
        i = np.searchsorted(sorted_list, v)
        # i == 0: at or below the first edge; i == len: above the last edge.
        # (NaN is expected to land past the last edge as well -- NumPy orders
        # NaN last.)
        if i == 0 or i == len(sorted_list):
            return None
        bounds.append((sorted_list[i-1], sorted_list[i]))
    return tuple(bounds)


def load_partitioned_examples(file_name: str, teacher: Teacher, partition) \
        -> Dict[Hashable, Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]], int]]:
    """Load samples from *file_name* and bucket them by partition cell.

    Args:
        file_name: Path to a pickle file holding a ``"truth_samples"`` entry.
        teacher: Supplies ``state_dim`` (how many leading entries of a sample
            form the state) and ``is_positive_example`` (the labelling rule).
        partition: One sorted sequence of cell edges per state dimension.

    Returns:
        A dict keyed by cell (a tuple of ``(lower, upper)`` pairs, one per
        dimension) whose value is the mutable triple
        ``[positive_examples, negative_examples, nan_count]``. Every cell of
        the partition has an entry, even if it received no samples. Samples
        whose state falls outside the partition are skipped and only counted
        in the printed summary.
    """
    truth_samples_seq = _load_truth_samples(file_name)

    bound_list = list(list(zip(x_arr[:-1], x_arr[1:])) for x_arr in partition)
    ret = {part: [[], [], 0] for part in itertools.product(*bound_list)}

    # Convert from sampled states and percepts to positive and negative examples for learning
    num_excl_samples = 0
    for _, ss in truth_samples_seq:
        for s in ss:
            state = s[0:teacher.state_dim]
            part = search_part(partition, state)
            if part is None:
                num_excl_samples += 1
                continue
            # The cell lookup happens first, so a sample counted here has an
            # in-range state and a NaN somewhere in the sample.
            if np.any(np.isnan(s)):
                ret[part][2] += 1
            elif teacher.is_positive_example(s):
                ret[part][0].append(s)
            else:
                ret[part][1].append(s)
    print("# excluded samples:", num_excl_samples)
    return ret


def main():
    """Plot (and optionally learn from) the examples of every partition cell.

    Results of the learning step are dumped as JSON under ``out/``.
    """
    X_LIM = np.inf
    X_ARR = np.array([-X_LIM, X_LIM])

    Y_LIM = 1.2
    NUM_Y_PARTS = 4
    Y_ARR = np.linspace(-Y_LIM, Y_LIM, NUM_Y_PARTS + 1)

    YAW_LIM = np.pi / 12
    NUM_YAW_PARTS = 10
    YAW_ARR = np.linspace(-YAW_LIM, YAW_LIM, NUM_YAW_PARTS + 1)

    PARTITION = (X_ARR, Y_ARR, YAW_ARR)

    PKL_FILE_PATH = "data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle"
    NORM_ORD = 1
    NUM_MAX_ITER = 500

    # XXX Temporary skip learning and only plot feature vectors.
    # Set to False to run the synthesis loop for each partition cell.
    PLOT_ONLY = True

    teacher = Teacher(norm_ord=NORM_ORD)
    part_to_examples = load_partitioned_examples(
        file_name=PKL_FILE_PATH,
        teacher=teacher, partition=PARTITION
    )

    result = []
    for part, (pos_exs, neg_exs, num_nan) in part_to_examples.items():
        print("#"*80)
        print(f"# positive: {len(pos_exs)}; "
              f"# negative: {len(neg_exs)}; "
              f"# NaN: {num_nan}")
        # np.asfarray was removed in NumPy 2.0; asarray with dtype=float is
        # the documented replacement.
        lb, ub = np.asarray(part, dtype=float).T

        # XXX Create new teacher and learner for each part to avoid solutions from other parts
        # TODO incremental solving
        teacher = Teacher(norm_ord=NORM_ORD)
        teacher.set_old_state_bound(lb=lb, ub=ub)

        learner = Learner(state_dim=teacher.state_dim,
                          perc_dim=teacher.perc_dim, timeout=20000)
        learner.set_grammar([(Teacher.PERC_GT, np.zeros(2))])

        if pos_exs:
            pos_fv_arr = np.asarray(
                [learner._s2f_func(exs) for exs in pos_exs], dtype=float)
            plt.scatter(pos_fv_arr[:, 0], pos_fv_arr[:, 1], c="g", marker="o")
        if neg_exs:
            neg_fv_arr = np.asarray(
                [learner._s2f_func(exs) for exs in neg_exs], dtype=float)
            plt.scatter(neg_fv_arr[:, 0], neg_fv_arr[:, 1], c="r", marker="x")
        plt.show()
        if PLOT_ONLY:
            continue

        learner.add_positive_examples(*pos_exs)
        learner.add_negative_examples(*neg_exs)
        try:
            found, ret = synth_dtree(learner, teacher,
                                     num_max_iterations=NUM_MAX_ITER)
            print(f"Found? {found}")
            if found:
                k, expr = ret
                result.append({"part": part,
                               "status": "found",
                               "result": {"k": k, "formula": expr}})
            else:
                result.append({"part": part,
                               "status": "not found",
                               "result": ret})
        except Exception as e:
            # Per-cell boundary: record the failure and move on to the next cell.
            result.append({"part": part,
                           "status": "exception",
                           "result": traceback.format_exc()})
            print(e)
        finally:
            del teacher
            del learner

    # NOTE: in plot-only mode `result` is empty, so an empty list is written.
    out_path = f"out/dtree_synth.{NUM_Y_PARTS}x{NUM_YAW_PARTS}.out.json"
    # Create the output directory so the results are not lost at the very end.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(result, f)


if __name__ == "__main__":
    # test_synth_dtree()
    main()