#!/usr/bin/env python3

import glob
import itertools
import json
import os
import pathlib
import pickle
import shutil
from typing import Dict, Hashable, Literal

import numpy as np

from dtree_synth import DataSet, search_part, synth_dtree_per_part
from dtree_learner import DTreeLearner
from gem_stanley_teacher import DTreeGEMStanleyGurobiStabilityTeacher
from teacher_base import TeacherBase

def load_partitioned_examples(file_name: str, teacher: DTreeGEMStanleyGurobiStabilityTeacher, partition) \
        -> Dict[Hashable, DataSet]:
    """Load pickled truth samples and bucket them by partition cell.

    Args:
        file_name: Path of the pickle file; the loaded object must provide a
            ``"truth_samples"`` entry that iterates as ``(_, samples)`` pairs.
        teacher: Supplies ``state_dim`` (length of the state prefix of every
            sample) and ``is_safe_state`` (safe/unsafe classifier).
        partition: One array of interval boundaries per state dimension.

    Returns:
        A dict with one ``DataSet`` per cell of the partition. A cell is a
        tuple of ``(lower, upper)`` pairs, one pair per dimension, and cells
        are inserted in ``itertools.product`` order.
    """
    print("Loading examples")
    # NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
    with open(file_name, "rb") as pickle_file_io:
        truth_samples_seq = pickle.load(pickle_file_io)["truth_samples"]

    # Consecutive boundary values form the (lower, upper) intervals of a dimension.
    intervals_per_dim = [list(zip(edges[:-1], edges[1:])) for edges in partition]
    datasets = {}
    for cell in itertools.product(*intervals_per_dim):
        datasets[cell] = DataSet()

    # Sort every sampled state/percept vector into the safe, unsafe, or NaN
    # bucket of the cell that contains its state.
    # TODO: the first element of each pair (ignored here) may be the ground
    # truth shared by all of its samples; confirm whether the safety check
    # should receive it.
    skipped = 0
    all_samples = itertools.chain.from_iterable(samples for _, samples in truth_samples_seq)
    for sample in all_samples:
        cell = search_part(partition, sample[0:teacher.state_dim])
        if cell is None:
            # The state lies outside every selected cell.
            skipped += 1
            continue

        bucket = datasets[cell]
        if np.any(np.isnan(sample)):
            bucket.num_nan_dps += 1
        elif teacher.is_safe_state(sample):
            bucket.safe_dps.append(sample)
        else:
            bucket.unsafe_dps.append(sample)

    print("# samples not in any selected parts:", skipped)
    return datasets


def main(dom: Literal["concat", "diff"], bnd_relax: float):
    """Run per-partition decision-tree synthesis and dump the result as JSON.

    The state space is partitioned into one unbounded x interval,
    ``NUM_Y_PARTS`` equal y intervals over ``[-Y_LIM, Y_LIM]`` and
    ``NUM_YAW_PARTS`` equal yaw intervals over ``[-YAW_LIM, YAW_LIM]``.
    Samples are loaded from ``PKL_FILE_PATH``, statistics are printed for every
    part containing unsafe samples, and the output of ``synth_dtree_per_part``
    is written to ``out/dtree_synth.<NUM_Y_PARTS>x<NUM_YAW_PARTS>.out.json``.

    Args:
        dom: Feature domain forwarded to the learner grammar and to
            ``synth_dtree_per_part``.
        bnd_relax: Forwarded as ``ultimate_bound`` to the teacher and as
            ``ult_bound`` to ``synth_dtree_per_part``.
    """
    X_LIM = np.inf
    X_ARR = np.array([-X_LIM, X_LIM])

    Y_LIM = 1.2
    NUM_Y_PARTS = 4
    Y_ARR = np.linspace(-Y_LIM, Y_LIM, NUM_Y_PARTS + 1)

    YAW_LIM = np.pi / 12
    NUM_YAW_PARTS = 10
    YAW_ARR = np.linspace(-YAW_LIM, YAW_LIM, NUM_YAW_PARTS + 1)

    PARTITION = (X_ARR, Y_ARR, YAW_ARR)

    PKL_FILE_PATH = "data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle"
    NUM_MAX_ITER = 500
    FEATURE_DOMAIN = dom
    ULT_BOUND = bnd_relax

    def teacher_builder():
        # NOTE(review): NORM_ORD is neither defined nor imported anywhere in
        # this file as reviewed, so this call raises NameError unless the name
        # is supplied elsewhere. Confirm the intended norm order and define it
        # next to the other constants above.
        return DTreeGEMStanleyGurobiStabilityTeacher(norm_ord=NORM_ORD, ultimate_bound=ULT_BOUND)

    def learner_builder(teacher: TeacherBase):
        learner = DTreeLearner(state_dim=teacher.state_dim,
                               perc_dim=teacher.perc_dim, timeout=20000)
        learner.set_grammar([(DTreeGEMStanleyGurobiStabilityTeacher.PERC_GT, np.zeros(2))], FEATURE_DOMAIN)
        return learner

    teacher = teacher_builder()
    part_to_examples = load_partitioned_examples(
        file_name=PKL_FILE_PATH,
        teacher=teacher, partition=PARTITION
    )

    # Print statistics about training data points
    print("#"*80)
    print("Parts with unsafe data points:")
    for i, (part, dataset) in enumerate(part_to_examples.items()):
        safe_dps, unsafe_dps, num_nan = dataset.safe_dps, dataset.unsafe_dps, dataset.num_nan_dps

        # `part` holds one (lower, upper) pair per dimension; transposing gives
        # the lower-bound and upper-bound vectors. np.asfarray was removed in
        # NumPy 2.0, so use np.asarray with an explicit float dtype instead.
        lb, ub = np.asarray(part, dtype=float).T
        # Yaw bounds are reported in degrees.
        lb[2] = np.rad2deg(lb[2])
        ub[2] = np.rad2deg(ub[2])

        if len(unsafe_dps) > 0:
            print(f"Part Index {i}:", f"y in [{lb[1]:.03}, {ub[1]:.03}] (m);", f"θ in [{lb[2]:.03}, {ub[2]:.03}] (deg);",
                  f"# safe: {len(safe_dps)}", f"# unsafe: {len(unsafe_dps):03}", f"# NaN: {num_nan}")

    # Create the output directory before the long-running synthesis so that
    # the final dump cannot fail merely because out/ is missing.
    os.makedirs("out", exist_ok=True)

    result = synth_dtree_per_part(
        part_to_examples,
        teacher_builder,
        learner_builder,
        num_max_iter=NUM_MAX_ITER,
        ult_bound=ULT_BOUND,
        feature_domain=FEATURE_DOMAIN
    )

    with open(f"out/dtree_synth.{NUM_Y_PARTS}x{NUM_YAW_PARTS}.out.json", "w") as f:
        json.dump(result, f)


# Post-run housekeeping helper (committed by aastorg2).
def clean_out_dir(dirname: str):
    """Archive ``out/`` under *dirname*, then prune generated files from ``out/``.

    The whole ``out`` directory is copied to ``<dirname>/out``. This is the
    layout ``cp -r out/ <dirname>`` produces with GNU cp when ``<dirname>``
    already exists. Only after the copy has succeeded are the files matching
    ``out/*.data``, ``out/*.json`` and ``out/pre.*`` deleted; every other file
    in ``out/`` is left in place.

    Args:
        dirname: Archive directory. It is created (with parents) if missing;
            an existing directory is reused and same-named files inside it
            are overwritten.

    Raises:
        OSError: If the archive directory cannot be created or ``out/`` cannot
            be copied (``shutil.Error`` is a subclass). Nothing is deleted in
            that case.
    """
    os.makedirs(dirname, exist_ok=True)
    # Copy before deleting: an exception here aborts the function, so results
    # are never removed without a successful archive.
    shutil.copytree("out", os.path.join(dirname, "out"), dirs_exist_ok=True)

    for pattern in ("out/*.data", "out/*.json", "out/pre.*"):
        for path in glob.glob(pattern):
            # Like plain `rm` without -r, leave directories untouched.
            if os.path.isfile(path) or os.path.islink(path):
                os.remove(path)

if __name__ == "__main__":
    # Each experiment runs main(<feature domain>, <bound relaxation>) and then
    # archives out/ into a directory named after that configuration.
    # Disabled experiments are kept below for reference.
    # main("diff", 0.0)  # EXPERIMENT: 2D, strict
    # clean_out_dir("diff0_0")
    # main("diff", 1.0)  # EXPERIMENT: 2D, relaxed
    # clean_out_dir("diff1_0")
    # main("concat", 0.0)  # EXPERIMENT: 4D, strict
    # clean_out_dir("concat0_0")
    main("concat", 1.0)  # EXPERIMENT: 4D, relaxed
    clean_out_dir("concat1_0")