#!/usr/bin/env python3
"""Synthesize decision-tree candidates per state partition using a CEGIS loop
between DTreeLearner and the GEM Stanley Gurobi-based teacher."""

import itertools
import json
import pathlib
import pickle
import traceback
from typing import Dict, Hashable, List, Tuple

import numpy as np
import z3

from dtree_learner import DTreeLearner as Learner
from dtree_teacher_gem_stanley import DTreeGEMStanleyGurobiTeacher as Teacher


def load_examples(file_name: str, spec) -> Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]]]:
    """Load sampled states and percepts and split them into positive and
    negative examples according to `spec`. Samples containing NaN are excluded."""
    print("Loading examples")
    with open(file_name, "rb") as pickle_file_io:
        pkl_data = pickle.load(pickle_file_io)
    truth_samples_seq = pkl_data["truth_samples"]

    # Convert from sampled states and percepts to positive and negative examples for learning
    pos_exs, neg_exs, num_excl_exs = [], [], 0
    for _, ss in truth_samples_seq:
        for s in ss:
            if np.any(np.isnan(s)):
                num_excl_exs += 1
            elif spec(s):
                pos_exs.append(s)
            else:
                neg_exs.append(s)
    print("# Excluded NaN examples:", num_excl_exs)
    return pos_exs, neg_exs


def synth_dtree(learner: Learner, teacher: Teacher, num_max_iterations: int = 10):
    """CEGIS loop: repeatedly learn a candidate and check it with the teacher,
    feeding negative examples back to the learner until the teacher reports unsat."""
    past_candidate_list = []
    for k in range(num_max_iterations):
        print("=" * 80)
        print(f"Iteration {k}:")
        print("learning ....")
        candidate = learner.learn()
        print("done learning")
        print(f"candidate: {candidate}")
        past_candidate_list.append(candidate)

        # Query the teacher for negative examples (counterexamples) to the candidate
        result = teacher.check(candidate)
        print(f"Satisfiability: {result}")
        if result == z3.sat:
            negative_examples = teacher.model()
            assert len(negative_examples) > 0
            print(f"negative examples: {negative_examples}")
            # assert validate_cexs(teacher.state_dim, teacher.perc_dim, candidate, negative_examples)
            learner.add_negative_examples(*negative_examples)
            continue
        elif result == z3.unsat:
            print("we are done!")
            return True, (k, z3.simplify(candidate, arith_lhs=True).sexpr())
        else:
            return False, f"Reason Unknown {teacher.reason_unknown()}"
    return False, f"Reached max iteration {num_max_iterations}."

def validate_cexs(state_dim: int, perc_dim: int, candidate: z3.BoolRef,
                  cexs: List[Tuple[float, ...]]) -> bool:
    """Return True iff none of the counterexamples from the teacher are spurious."""
    spurious_cexs = [cex for cex in cexs
                     if Teacher.is_spurious_example(state_dim, perc_dim, candidate, cex)]
    if not spurious_cexs:
        return True
    print("Spurious CEXs:", *spurious_cexs, sep='\n')
    return False


def search_part(partition, state):
    """Return the per-dimension (lower, upper) bounds of the partition cell
    containing `state`, or None if `state` falls outside the partition."""
    assert len(partition) == len(state)
    bounds = []
    for sorted_list, v in zip(partition, state):
        i = np.searchsorted(sorted_list, v)
        if i == 0 or i == len(sorted_list):
            return None
        bounds.append((sorted_list[i - 1], sorted_list[i]))
    return tuple(bounds)


def load_partitioned_examples(file_name: str, teacher: Teacher, partition) \
        -> Dict[Hashable, Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]], int]]:
    """Load sampled states and percepts and group them by partition cell into
    (positive examples, negative examples, NaN sample count) triples."""
    print("Loading examples")
    with open(file_name, "rb") as pickle_file_io:
        pkl_data = pickle.load(pickle_file_io)
    truth_samples_seq = pkl_data["truth_samples"]

    bound_list = list(list(zip(x_arr[:-1], x_arr[1:])) for x_arr in partition)
    ret = {part: [[], [], 0] for part in itertools.product(*bound_list)}

    # Convert from sampled states and percepts to positive and negative examples for learning
    num_excl_samples = 0
    for _, ss in truth_samples_seq:
        for s in ss:
            state = s[0:teacher.state_dim]
            part = search_part(partition, state)
            if part is None:
                num_excl_samples += 1
                continue
            if np.any(np.isnan(s)):
                ret[part][2] += 1
            elif teacher.is_positive_example(s):
                ret[part][0].append(s)
            else:
                ret[part][1].append(s)
    print("# excluded samples:", num_excl_samples)
    return ret


def main():
    X_LIM = np.inf
    X_ARR = np.array([-X_LIM, X_LIM])
    Y_LIM = 1.2
    NUM_Y_PARTS = 4
    Y_ARR = np.linspace(-Y_LIM, Y_LIM, NUM_Y_PARTS + 1)
    YAW_LIM = np.pi / 12
    NUM_YAW_PARTS = 10
    YAW_ARR = np.linspace(-YAW_LIM, YAW_LIM, NUM_YAW_PARTS + 1)
    PARTITION = (X_ARR, Y_ARR, YAW_ARR)
    PKL_FILE_PATH = "data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle"
    NORM_ORD = 1
    NUM_MAX_ITER = 500

    teacher = Teacher(norm_ord=NORM_ORD)
    part_to_examples = load_partitioned_examples(
        file_name=PKL_FILE_PATH,
        teacher=teacher,
        partition=PARTITION)

    result = []
    for i, (part, (pos_exs, neg_exs, num_nan)) in enumerate(part_to_examples.items()):
        print("#" * 80)
        print(f"# positive: {len(pos_exs)}; "
              f"# negative: {len(neg_exs)}; "
              f"# NaN: {num_nan}")
        lb, ub = np.asfarray(part).T
        # XXX Create new teacher and learner for each part to avoid solutions from other parts
        # TODO incremental solving
        teacher = Teacher(norm_ord=NORM_ORD)
        teacher.set_old_state_bound(lb=lb, ub=ub)
        learner = Learner(state_dim=teacher.state_dim,
                          perc_dim=teacher.perc_dim, timeout=20000)
        learner.set_grammar([(Teacher.PERC_GT, np.zeros(2))])
        learner.add_positive_examples(*pos_exs)
        learner.add_negative_examples(*neg_exs)
        try:
            found, ret = synth_dtree(learner, teacher, num_max_iterations=NUM_MAX_ITER)
            print(f"Found? {found}")
            if found:
                k, expr = ret
                result.append({"part": part,
                               "status": "found",
                               "result": {"k": k, "formula": expr}})
            else:
                result.append({"part": part,
                               "status": "not found",
                               "result": ret})
        except KeyboardInterrupt:
            print("User pressed Ctrl+C. Skipping all remaining partitions.")
            break  # NOTE finally block is executed before break
        except Exception as e:
            result.append({"part": part,
                           "status": "exception",
                           "result": traceback.format_exc()})
            print(e)
        finally:
            # Rename the intermediate data file so each partition keeps its own copy
            data_file = pathlib.Path("out/pre.data")
            data_file.rename(f"out/part-{i:03}-pre.data")
            del teacher
            del learner

    with open(f"out/dtree_synth.{NUM_Y_PARTS}x{NUM_YAW_PARTS}.out.json", "w") as f:
        json.dump(result, f)


if __name__ == "__main__":
    main()