#!/usr/bin/env python3
import itertools
import json
import os
import pathlib
import pickle
import shutil
from typing import Dict, Hashable, Literal

import numpy as np

from dtree_synth import DataSet, search_part, synth_dtree_per_part
from dtree_learner import DTreeLearner
from gem_stanley_teacher import DTreeGEMStanleyGurobiStabilityTeacher
from teacher_base import TeacherBase


def load_partitioned_examples(file_name: str,
                              teacher: DTreeGEMStanleyGurobiStabilityTeacher,
                              partition) \
        -> Dict[Hashable, DataSet]:
    """Load sampled data points from a pickle file and bucket them by partition cell.

    Args:
        file_name: Path to a pickle file containing a dict with key
            ``"truth_samples"``; each entry of that sequence is unpacked as a
            pair whose second element is an iterable of samples.
        teacher: Used for ``state_dim`` (the leading entries of each sample
            treated as the state) and ``is_safe_state`` (safe/unsafe label).
        partition: One sorted array of grid points per state dimension
            (e.g. ``(X_ARR, Y_ARR, YAW_ARR)`` in ``main``). Consecutive grid
            points form the (lower, upper) bounds of each interval.

    Returns:
        A dict keyed by every cell of the partition (a tuple with one
        ``(lb, ub)`` pair per dimension) mapping to a ``DataSet`` holding the
        safe samples, unsafe samples, and a count of samples containing NaN.
        Samples whose state lies in no cell are counted and reported, not stored.
    """
    print("Loading examples")
    # NOTE: pickle.load can execute arbitrary code; only load trusted data files.
    with open(file_name, "rb") as pickle_file_io:
        pkl_data = pickle.load(pickle_file_io)
    truth_samples_seq = pkl_data["truth_samples"]

    # Per dimension, turn grid points [g0, g1, ..., gk] into intervals [(g0, g1), ..., (g(k-1), gk)].
    bound_list = [list(zip(x_arr[:-1], x_arr[1:])) for x_arr in partition]
    # One (initially empty) DataSet per cell of the Cartesian product of intervals.
    ret = {part: DataSet() for part in itertools.product(*bound_list)}

    # Convert sampled states and percepts into safe/unsafe examples for learning.
    num_excl_samples = 0
    for _, ss in truth_samples_seq:
        # TODO(review): the first element of each entry (discarded as `_`) is
        # never passed to the teacher's safety check. Is it the ground truth
        # for every sample in `ss`? Confirm whether it should be used.
        for s in ss:
            state = s[0:teacher.state_dim]
            part = search_part(partition, state)
            if part is None:
                num_excl_samples += 1
                continue
            if np.any(np.isnan(s)):
                ret[part].num_nan_dps += 1
            elif teacher.is_safe_state(s):
                ret[part].safe_dps.append(s)
            else:
                ret[part].unsafe_dps.append(s)
    print("# samples not in any selected parts:", num_excl_samples)
    return ret


def main(dom: Literal["concat", "diff"], bnd_relax: float):
    """Run decision-tree synthesis on every cell of a fixed (x, y, yaw) partition.

    Args:
        dom: Feature domain handed to the learner grammar and to
            ``synth_dtree_per_part``.
        bnd_relax: Ultimate bound handed to the teacher and to
            ``synth_dtree_per_part``.

    Writes the synthesis result as JSON to
    ``out/dtree_synth.<NUM_Y_PARTS>x<NUM_YAW_PARTS>.out.json``.
    """
    X_LIM = np.inf
    X_ARR = np.array([-X_LIM, X_LIM])
    Y_LIM = 1.2
    NUM_Y_PARTS = 4
    Y_ARR = np.linspace(-Y_LIM, Y_LIM, NUM_Y_PARTS + 1)
    YAW_LIM = np.pi / 12
    NUM_YAW_PARTS = 10
    YAW_ARR = np.linspace(-YAW_LIM, YAW_LIM, NUM_YAW_PARTS + 1)
    PARTITION = (X_ARR, Y_ARR, YAW_ARR)
    PKL_FILE_PATH = "data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle"
    NORM_ORD = 2
    NUM_MAX_ITER = 500
    FEATURE_DOMAIN = dom
    ULT_BOUND = bnd_relax

    def teacher_builder():
        return DTreeGEMStanleyGurobiStabilityTeacher(norm_ord=NORM_ORD, ultimate_bound=ULT_BOUND)

    def learner_builder(teacher: TeacherBase):
        learner = DTreeLearner(state_dim=teacher.state_dim,
                               perc_dim=teacher.perc_dim, timeout=20000)
        learner.set_grammar([(DTreeGEMStanleyGurobiStabilityTeacher.PERC_GT, np.zeros(2))],
                            FEATURE_DOMAIN)
        return learner

    teacher = teacher_builder()
    part_to_examples = load_partitioned_examples(
        file_name=PKL_FILE_PATH,
        teacher=teacher, partition=PARTITION
    )

    # Print statistics about training data points
    print("#"*80)
    print("Parts with unsafe data points:")
    for i, (part, dataset) in enumerate(part_to_examples.items()):
        safe_dps, unsafe_dps, num_nan = dataset.safe_dps, dataset.unsafe_dps, dataset.num_nan_dps
        # `part` holds one (lb, ub) pair per dimension; transpose to get the lb row and ub row.
        # np.asfarray was removed in NumPy 2.0; asarray with dtype=float is the replacement.
        lb, ub = np.asarray(part, dtype=float).T
        # Index 2 is yaw (see PARTITION); shown in degrees for readability.
        lb[2] = np.rad2deg(lb[2])
        ub[2] = np.rad2deg(ub[2])
        if len(unsafe_dps) > 0:
            print(f"Part Index {i}:",
                  f"y in [{lb[1]:.03}, {ub[1]:.03}] (m);",
                  f"θ in [{lb[2]:.03}, {ub[2]:.03}] (deg);",
                  f"# safe: {len(safe_dps)}",
                  f"# unsafe: {len(unsafe_dps):03}",
                  f"# NaN: {num_nan}")

    result = synth_dtree_per_part(
        part_to_examples,
        teacher_builder,
        learner_builder,
        num_max_iter=NUM_MAX_ITER,
        ult_bound=ULT_BOUND,
        feature_domain=FEATURE_DOMAIN
    )

    # Make sure the output directory exists before writing the result file.
    os.makedirs("out", exist_ok=True)
    with open(f"out/dtree_synth.{NUM_Y_PARTS}x{NUM_YAW_PARTS}.out.json", "w") as f:
        json.dump(result, f)


def clean_out_dir(dirname: str):
    """Archive the ``out/`` directory under ``dirname`` and then tidy ``out/``.

    Reproduces the former shell sequence
    ``mkdir <dirname>; cp -r out/ <dirname>; rm out/*.data out/*.json out/pre.*``
    without spawning a shell. The copy therefore lands in ``<dirname>/out/``.
    Best-effort: an existing ``dirname`` is reused, and a missing ``out/``
    is reported rather than raised.
    """
    src = pathlib.Path("out")
    dst = pathlib.Path(dirname)
    dst.mkdir(parents=True, exist_ok=True)
    if not src.is_dir():
        print(f"clean_out_dir: '{src}' does not exist; nothing to archive")
        return
    # `cp -r out/ <existing dir>` copies the directory *into* it, i.e. <dir>/out.
    shutil.copytree(src, dst / src.name, dirs_exist_ok=True)
    for pattern in ("*.data", "*.json", "pre.*"):
        for path in src.glob(pattern):
            # Plain `rm` (no -r) removes files and symlinks but never directories.
            if path.is_symlink() or path.is_file():
                path.unlink()


if __name__ == "__main__":
    # main("diff", 0.0)  # EXPERIMENT: 2D, strict
    # clean_out_dir("diff0_0")
    # main("diff", 1.0)  # EXPERIMENT: 2D, relaxed
    # clean_out_dir("diff1_0")
    # main("concat", 0.0)  # EXPERIMENT: 4D, strict
    # clean_out_dir("concat0_0")
    main("concat", 1.0)  # EXPERIMENT: 4D, relaxed
    clean_out_dir("concat1_0")