# dtree_synth.py
#!/usr/bin/env python3

import itertools
import json
import matplotlib.pyplot as plt
import pathlib
import pickle
import traceback
from typing import Dict, Hashable, List, Tuple

import numpy as np
import z3
from dtree_learner import DTreeLearner as Learner
from dtree_teacher_gem_stanley import DTreeGEMStanleyGurobiTeacher as Teacher
def load_examples(file_name: str, spec) -> Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]]]:
    """Load sampled data from a pickle file and split it into positive/negative examples.

    The pickle must contain a mapping with a ``"truth_samples"`` entry that
    iterates as ``(key, samples)`` pairs; the first element of each pair is
    ignored here and only the samples are classified.

    Args:
        file_name: Path to the pickle file to read.
        spec: Callable applied to every NaN-free sample; a truthy result marks
            the sample as a positive example, a falsy one as negative.

    Returns:
        ``(pos_exs, neg_exs)``. Samples containing any NaN appear in neither
        list; only their count is printed.
    """
    print("Loading examples")
    # NOTE(review): pickle.load can execute arbitrary code while deserializing;
    # only point this at trusted data files.
    with open(file_name, "rb") as pickle_file_io:
        pkl_data = pickle.load(pickle_file_io)

    truth_samples_seq = pkl_data["truth_samples"]

    # Convert from sampled states and percepts to positive and negative examples for learning
    pos_exs, neg_exs, num_excl_exs = [], [], 0
    for _, ss in truth_samples_seq:
        for s in ss:
            if np.any(np.isnan(s)):
                # NaN samples cannot be classified by spec; count and drop them.
                num_excl_exs += 1
            elif spec(s):
                pos_exs.append(s)
            else:
                neg_exs.append(s)
    print("# Excluded NaN examples:", num_excl_exs)
    return pos_exs, neg_exs
def synth_dtree(learner: Learner, teacher: Teacher, num_max_iterations: int = 10):
    """Alternate between learner and teacher until a candidate is accepted.

    Each iteration asks ``learner`` for a candidate and passes it to
    ``teacher.check``. The result is compared against z3's check results:

    * ``z3.sat``: the teacher's model is fed back to the learner as new
      negative examples and the loop continues.
    * ``z3.unsat``: the candidate is accepted and returned.
    * anything else: the search stops and the teacher's ``reason_unknown()``
      is reported.

    Args:
        learner: Provides ``learn()`` and ``add_negative_examples(*exs)``.
        teacher: Provides ``check(candidate)``, ``model()`` and
            ``reason_unknown()``.
        num_max_iterations: Upper bound on learner/teacher rounds.

    Returns:
        ``(True, (k, formula))`` when the candidate of iteration ``k`` is
        accepted, where ``formula`` is the s-expression of the simplified
        candidate; otherwise ``(False, message)`` with a human-readable reason.
    """
    for k in range(num_max_iterations):
        print("="*80)
        print(f"Iteration {k}:", sep='')
        print("learning ....")

        candidate = learner.learn()
        print("done learning")
        print(f"candidate: {candidate}")
        # QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES
        result = teacher.check(candidate)
        print(f"Satisfiability: {result}")
        if result == z3.sat:
            negative_examples = teacher.model()
            assert len(negative_examples) > 0
            print(f"negative examples: {negative_examples}")

            # assert validate_cexs(teacher.state_dim, teacher.perc_dim, candidate, negative_examples)
            learner.add_negative_examples(*negative_examples)
            continue
        elif result == z3.unsat:
            print("we are done!")
            return True, (k, z3.simplify(candidate, arith_lhs=True).sexpr())
        else:
            # Neither sat nor unsat: no new examples would be added, so
            # re-learning would only repeat the same query. Stop and report.
            return False, f"Reason Unknown {teacher.reason_unknown()}"
    return False, f"Reached max iteration {num_max_iterations}."
def validate_cexs(state_dim: int, perc_dim: int,
                  candidate: z3.BoolRef,
                  cexs: List[Tuple[float]]) -> bool:
    """Check that none of the counterexamples is flagged as spurious.

    Each entry of ``cexs`` is passed to ``Teacher.is_spurious_example``
    together with ``state_dim``, ``perc_dim`` and ``candidate``.

    Returns:
        ``True`` when no entry is flagged; otherwise prints the flagged
        entries, one per line, and returns ``False``.
    """
    flagged = []
    for example in cexs:
        if Teacher.is_spurious_example(state_dim, perc_dim, candidate, example):
            flagged.append(example)

    if flagged:
        print("Spurious CEXs:", *flagged, sep='\n')
        return False
    return True
def search_part(partition, state):
    """Locate the grid cell of ``partition`` that contains ``state``.

    ``partition`` holds one array of cell edges per state dimension
    (assumed sorted ascending, as ``np.searchsorted`` requires). For each
    dimension the value is matched to the interval ``(lo, hi]`` between two
    consecutive edges.

    Returns:
        A tuple of ``(lo, hi)`` pairs, one per dimension, or ``None`` as soon
        as one value is less than or equal to the first edge or greater than
        the last edge of its dimension.
    """
    assert len(partition) == len(state)
    cell = []
    for edges, value in zip(partition, state):
        idx = np.searchsorted(edges, value)
        if 0 < idx < len(edges):
            cell.append((edges[idx - 1], edges[idx]))
        else:
            # Value is outside the covered range for this dimension.
            return None
    return tuple(cell)


def load_partitioned_examples(file_name: str, teacher: Teacher, partition) \
        -> Dict[Hashable, Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]], int]]:
    """Load samples from a pickle file and bucket them by partition cell.

    The pickle must contain a mapping with a ``"truth_samples"`` entry that
    iterates as ``(key, samples)`` pairs. The leading ``teacher.state_dim``
    entries of each sample are used to find its cell via ``search_part``.

    Args:
        file_name: Path to the pickle file to read.
        teacher: Supplies ``state_dim`` and ``is_positive_example(sample)``.
        partition: One array of cell edges per state dimension.

    Returns:
        A dict keyed by cell (a tuple of ``(lo, hi)`` pairs, one per
        dimension) covering every cell of the grid. Each value is a
        three-item list ``[positives, negatives, nan_count]``. Samples for
        which ``search_part`` returns ``None`` are dropped; only their count
        is printed.
    """
    print("Loading examples")
    with open(file_name, "rb") as pickle_file_io:
        loaded = pickle.load(pickle_file_io)

    truth_samples_seq = loaded["truth_samples"]

    # Consecutive edge pairs per dimension; their cross product enumerates all cells.
    intervals_per_dim = [list(zip(edges[:-1], edges[1:])) for edges in partition]
    buckets = {cell: [[], [], 0] for cell in itertools.product(*intervals_per_dim)}

    # Convert from sampled states and percepts to positive and negative examples for learning
    skipped = 0
    for _, samples in truth_samples_seq:
        for sample in samples:
            cell = search_part(partition, sample[0:teacher.state_dim])
            if cell is None:
                skipped += 1
                continue
            if np.any(np.isnan(sample)):
                buckets[cell][2] += 1
            elif teacher.is_positive_example(sample):
                buckets[cell][0].append(sample)
            else:
                buckets[cell][1].append(sample)
    print("# excluded samples:", skipped)
    return buckets


def main():
    """Synthesize one decision-tree formula per partition cell and dump the results.

    Builds a grid over the state space, loads the pickled samples bucketed by
    cell, then runs ``synth_dtree`` with a fresh teacher and learner for every
    cell. The per-cell outcome (found / not found / exception) is collected
    and written to a JSON file under ``out/``.
    """
    # The first dimension is a single unbounded interval, i.e. it is not split.
    X_LIM = np.inf
    X_ARR = np.array([-X_LIM, X_LIM])

    Y_LIM = 1.2
    NUM_Y_PARTS = 4
    Y_ARR = np.linspace(-Y_LIM, Y_LIM, NUM_Y_PARTS + 1)

    YAW_LIM = np.pi / 12
    NUM_YAW_PARTS = 10
    YAW_ARR = np.linspace(-YAW_LIM, YAW_LIM, NUM_YAW_PARTS + 1)

    PARTITION = (X_ARR, Y_ARR, YAW_ARR)

    PKL_FILE_PATH = "data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle"
    NORM_ORD = 1
    NUM_MAX_ITER = 500

    # This teacher is only used to classify the loaded samples; each cell
    # gets its own teacher below.
    teacher = Teacher(norm_ord=NORM_ORD)
    part_to_examples = load_partitioned_examples(
        file_name=PKL_FILE_PATH,
        teacher=teacher, partition=PARTITION
    )

    result = []
    for i, (part, (pos_exs, neg_exs, num_nan)) in enumerate(part_to_examples.items()):
        print("#"*80)
        print(f"# positive: {len(pos_exs)}; "
              f"# negative: {len(neg_exs)}; "
              f"# NaN: {num_nan}")
        # part is a tuple of (lo, hi) pairs; transpose into lower/upper bound vectors.
        # np.asfarray was removed in NumPy 2.0, so use asarray with an explicit dtype.
        lb, ub = np.asarray(part, dtype=float).T

        # XXX Create new teacher and learner for each part to avoid solutions from other parts
        # TODO incremental solving
        teacher = Teacher(norm_ord=NORM_ORD)
        teacher.set_old_state_bound(lb=lb, ub=ub)

        learner = Learner(state_dim=teacher.state_dim,
                          perc_dim=teacher.perc_dim, timeout=20000)
        learner.set_grammar([(Teacher.PERC_GT, np.zeros(2))])

        learner.add_positive_examples(*pos_exs)
        learner.add_negative_examples(*neg_exs)
        try:
            found, ret = synth_dtree(learner, teacher,
                                     num_max_iterations=NUM_MAX_ITER)
            print(f"Found? {found}")
            if found:
                k, expr = ret
                result.append({"part": part,
                               "status": "found",
                               "result": {"k": k, "formula": expr}})
            else:
                result.append({"part": part,
                               "status": "not found",
                               "result": ret})
        except KeyboardInterrupt:
            print("User pressed Ctrl+C. Skip all remaining partition.")
            break  # NOTE finally block is executed before break
        except Exception as e:
            result.append({"part": part,
                           "status": "exception",
                           "result": traceback.format_exc()})
            print(e)
        finally:
            # Keep a per-part copy of the intermediate data file.
            # NOTE(review): out/pre.data is presumably written by the learner - confirm.
            # It may be missing if synthesis failed early; an unguarded rename
            # would then raise from this finally block, masking the handled
            # error and skipping the JSON dump below.
            data_file = pathlib.Path("out/pre.data")
            if data_file.exists():
                data_file.rename(f"out/part-{i:03}-pre.data")
            else:
                print(f"No out/pre.data to archive for part {i:03}.")
            del teacher
            del learner

    with open(f"out/dtree_synth.{NUM_Y_PARTS}x{NUM_YAW_PARTS}.out.json", "w") as f:
        json.dump(result, f)
if __name__ == "__main__":