#!/usr/bin/env python3

import itertools
import json
import matplotlib.pyplot as plt
import pickle
import traceback
from typing import Dict, Hashable, List, Tuple

import numpy as np
import z3
from dtree_learner import DTreeLearner as Learner
from dtree_teacher_gem_stanley import DTreeGEMStanleyGurobiTeacher as Teacher
def load_examples(file_name: str, spec) -> Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]]]:
    """Load sampled examples from a pickle file and split them with ``spec``.

    The pickle must contain a dict whose ``"truth_samples"`` entry is a
    sequence of ``(_, samples)`` pairs.  Every sample containing a NaN is
    dropped (the number dropped is printed); each remaining sample goes to
    the positive list when ``spec(sample)`` is truthy and to the negative
    list otherwise.

    Args:
        file_name: Path of the pickle file.  NOTE: ``pickle.load`` can run
            arbitrary code, so only load files from a trusted source.
        spec: Predicate deciding whether a sample is a positive example.

    Returns:
        ``(positive_examples, negative_examples)``.
    """
    print("Loading examples")
    with open(file_name, "rb") as pickle_file_io:
        pkl_data = pickle.load(pickle_file_io)

    truth_samples_seq = pkl_data["truth_samples"]

    # Convert from sampled states and percepts to positive and negative examples for learning
    pos_exs, neg_exs, num_excl_exs = [], [], 0
    for _, ss in truth_samples_seq:
        for s in ss:
            if np.any(np.isnan(s)):
                # NaN samples cannot be classified by spec; count and skip them.
                num_excl_exs += 1
            elif spec(s):
                pos_exs.append(s)
            else:
                neg_exs.append(s)
    print("# Excluded NaN examples:", num_excl_exs)
    return pos_exs, neg_exs


def test_synth_dtree():
    """ Test using filtered data where we know an abstraction exists.

    Builds a teacher with a fixed old-state bound, loads pre-filtered
    examples, seeds the learner with the single grammar entry
    ``(Teacher.PERC_GT, zeros(2))`` and runs ``synth_dtree`` for up to 2000
    iterations, printing the outcome.
    """

    # Initialize Teacher
    teacher = Teacher(norm_ord=1, ultimate_bound=0.25)
    teacher.set_old_state_bound(lb=[0.0, -1.0, 0.2], ub=[32.0, -0.9, 0.22])

    init_positive_examples, init_negative_examples = load_examples(
        "data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle",
        teacher.is_positive_example)

    print("# positive examples: %d" % len(init_positive_examples))
    print("# negative examples: %d" % len(init_negative_examples))
    # The example dimension is read from the first positive example, so fail
    # with a clear message rather than an IndexError when there are none.
    assert init_positive_examples, "No positive examples were loaded"
    ex_dim = len(init_positive_examples[0])
    print("Dimension of each example: %d" % ex_dim)
    assert all(len(ex) == ex_dim and not any(np.isnan(ex))
               for ex in init_positive_examples)
    assert teacher.state_dim + teacher.perc_dim == ex_dim

    # Initialize Learner
    learner = Learner(state_dim=teacher.state_dim,
                      perc_dim=teacher.perc_dim, timeout=20000)
    a_mat_0 = Teacher.PERC_GT
    b_vec_0 = np.zeros(2)
    # Let z = [z_0, z_1] = [d, psi]; x = [x_0, x_1, x_2] = [x, y, theta]
    # a_mat_0 @ [x, y, theta] + b_vec_0 = [-y, -theta]
    # z - (a_mat_0 @ x + b_vec_0) = [d, psi] - [-y, -theta] = [d+y, psi+theta] defined as [fvar0_A0, fvar1_A0]

    learner.set_grammar([(a_mat_0, b_vec_0)])
    learner.add_positive_examples(*init_positive_examples)
    learner.add_negative_examples(*init_negative_examples)

    found, ret = synth_dtree(learner, teacher, num_max_iterations=2000)
    print(f"Found? {found}")
    print(ret)

def synth_dtree(learner: Learner, teacher: Teacher, num_max_iterations: int = 10):
    """Run the learner/teacher counterexample-guided synthesis loop.

    Each round asks ``learner`` for a candidate and ``teacher`` to check it:

    * ``z3.sat``   -- the teacher found negative examples; they are added to
      the learner and the loop continues.
    * ``z3.unsat`` -- the candidate is accepted.
    * otherwise    -- the check was inconclusive; synthesis stops.

    Args:
        learner: Produces candidates via ``learn()`` and accepts new
            negative examples via ``add_negative_examples``.
        teacher: Checks candidates via ``check()``; exposes ``model()`` and
            ``reason_unknown()``.
        num_max_iterations: Maximum number of learn/check rounds.

    Returns:
        ``(True, (k, sexpr))`` when the candidate of iteration ``k`` is
        accepted, ``sexpr`` being its simplified S-expression; otherwise
        ``(False, reason)`` with a human-readable reason string.
    """
    for k in range(num_max_iterations):
        print("=" * 80)
        print(f"Iteration {k}:")
        print("learning ....")

        candidate = learner.learn()
        print("done learning")
        print(f"candidate: {candidate}")
        # QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES
        result = teacher.check(candidate)
        print(f"Satisfiability: {result}")
        if result == z3.sat:
            negative_examples = teacher.model()
            assert len(negative_examples) > 0
            print(f"negative examples: {negative_examples}")

            # Debug aid:
            # assert validate_cexs(teacher.state_dim, teacher.perc_dim, candidate, negative_examples)
            learner.add_negative_examples(*negative_examples)
        elif result == z3.unsat:
            print("we are done!")
            return True, (k, z3.simplify(candidate, arith_lhs=True).sexpr())
        else:
            # Inconclusive check (e.g. z3.unknown): no new examples were
            # produced, so another round would presumably just repeat the
            # same candidate.  Stop and report why.
            return False, f"Reason Unknown {teacher.reason_unknown()}"
    return False, f"Reached max iteration {num_max_iterations}."
def validate_cexs(state_dim: int, perc_dim: int,
                  candidate: z3.BoolRef,
                  cexs: List[Tuple[float, ...]]) -> bool:
    """Check that none of the counterexamples is spurious for ``candidate``.

    A counterexample is spurious when
    ``Teacher.is_spurious_example(state_dim, perc_dim, candidate, cex)`` is
    truthy.  Any spurious counterexamples are printed, one per line.

    Args:
        state_dim: Number of state components in each counterexample.
        perc_dim: Number of percept components in each counterexample.
        candidate: The candidate formula the counterexamples refute.
        cexs: Counterexample vectors (state followed by percept).

    Returns:
        True if no counterexample is spurious, False otherwise.
    """
    spurious_cexs = [cex for cex in cexs
                     if Teacher.is_spurious_example(state_dim, perc_dim, candidate, cex)]
    if spurious_cexs:
        print("Spurious CEXs:", *spurious_cexs, sep='\n')
        return False
    return True
def search_part(partition, state):
    """Return the partition cell that contains ``state``, or None.

    ``partition`` holds one sorted array of edges per dimension.  For each
    coordinate ``v`` the cell is ``(edges[i-1], edges[i])``, where ``i`` is
    ``np.searchsorted(edges, v)``.  That makes each cell the half-open
    interval ``(edges[i-1], edges[i]]``.  If any coordinate is <= the first
    edge or > the last edge, the state lies outside the partition and None
    is returned.
    """
    assert len(partition) == len(state)
    cell = []
    for edges, value in zip(partition, state):
        idx = np.searchsorted(edges, value)
        if not 0 < idx < len(edges):
            # Outside the outermost edges in this dimension.
            return None
        cell.append((edges[idx - 1], edges[idx]))
    return tuple(cell)


def load_partitioned_examples(file_name: str, teacher: Teacher, partition) \
        -> Dict[Hashable, Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]], int]]:
    """Load sampled examples from a pickle file and bucket them by partition cell.

    Every cell of ``partition`` is a key of the result, even when no samples
    fall into it.  Each value is a mutable ``[positives, negatives, nan_count]``
    list.  A sample is assigned to the cell containing its first
    ``teacher.state_dim`` components:

    * samples whose state lies outside the partition are counted and skipped;
    * samples containing a NaN only increment the cell's NaN count;
    * the rest are positives or negatives per ``teacher.is_positive_example``.

    NOTE: ``pickle.load`` can run arbitrary code; only load trusted files.
    """
    print("Loading examples")
    with open(file_name, "rb") as fin:
        pickled = pickle.load(fin)
    samples_by_seq = pickled["truth_samples"]

    # Consecutive edge pairs per dimension; their product enumerates all cells.
    cell_edges = [list(zip(edges[:-1], edges[1:])) for edges in partition]
    buckets = {cell: [[], [], 0] for cell in itertools.product(*cell_edges)}

    # Convert from sampled states and percepts to positive and negative examples for learning
    num_outside = 0
    for _, samples in samples_by_seq:
        for sample in samples:
            cell = search_part(partition, sample[:teacher.state_dim])
            if cell is None:
                num_outside += 1
            elif np.any(np.isnan(sample)):
                buckets[cell][2] += 1
            elif teacher.is_positive_example(sample):
                buckets[cell][0].append(sample)
            else:
                buckets[cell][1].append(sample)
    print("# excluded samples:", num_outside)
    return buckets


def main():
    """Run per-cell synthesis over a grid partition of the (x, y, yaw) state space.

    x is left unbounded.  y is split into ``NUM_Y_PARTS`` cells over
    [-Y_LIM, Y_LIM] and yaw into ``NUM_YAW_PARTS`` cells over
    [-YAW_LIM, YAW_LIM].  Examples are loaded and bucketed per cell, then a
    fresh teacher/learner pair is built for each cell.  The loop currently
    only plots feature vectors (see the ``continue`` marked XXX).  The
    collected results are written to ``out/dtree_synth.<Y>x<YAW>.out.json``.
    """
    import os  # only needed to make sure the output directory exists

    X_LIM = np.inf
    X_ARR = np.array([-X_LIM, X_LIM])

    Y_LIM = 1.2
    NUM_Y_PARTS = 4
    Y_ARR = np.linspace(-Y_LIM, Y_LIM, NUM_Y_PARTS + 1)

    YAW_LIM = np.pi / 12
    NUM_YAW_PARTS = 10
    YAW_ARR = np.linspace(-YAW_LIM, YAW_LIM, NUM_YAW_PARTS + 1)

    PARTITION = (X_ARR, Y_ARR, YAW_ARR)

    PKL_FILE_PATH = "data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle"
    NORM_ORD = 1
    NUM_MAX_ITER = 500

    # Create the output directory before the (long) loop so a missing or
    # unwritable directory fails fast instead of losing all results at the end.
    os.makedirs("out", exist_ok=True)

    teacher = Teacher(norm_ord=NORM_ORD)
    part_to_examples = load_partitioned_examples(
        file_name=PKL_FILE_PATH,
        teacher=teacher, partition=PARTITION
    )

    result = []
    for part, (pos_exs, neg_exs, num_nan) in part_to_examples.items():
        print("#"*80)
        print(f"# positive: {len(pos_exs)}; "
              f"# negative: {len(neg_exs)}; "
              f"# NaN: {num_nan}")
        # part is a tuple of per-dimension (lo, hi) pairs; transpose into bounds.
        # (np.asfarray was removed in NumPy 2.0; asarray(dtype=float) replaces it.)
        lb, ub = np.asarray(part, dtype=float).T

        # XXX Create new teacher and learner for each part to avoid solutions from other parts
        # TODO incremental solving
        teacher = Teacher(norm_ord=NORM_ORD)
        teacher.set_old_state_bound(lb=lb, ub=ub)

        learner = Learner(state_dim=teacher.state_dim,
                          perc_dim=teacher.perc_dim, timeout=20000)
        learner.set_grammar([(Teacher.PERC_GT, np.zeros(2))])

        if pos_exs:
            pos_fv_arr = np.asarray([learner._s2f_func(exs) for exs in pos_exs], dtype=float)
            plt.scatter(pos_fv_arr[:, 0], pos_fv_arr[:, 1], c="g", marker="o")
        if neg_exs:
            neg_fv_arr = np.asarray([learner._s2f_func(exs) for exs in neg_exs], dtype=float)
            plt.scatter(neg_fv_arr[:, 0], neg_fv_arr[:, 1], c="r", marker="x")
        plt.show()

        continue  # XXX Temporary skip learning and only plot feature vectors

        learner.add_positive_examples(*pos_exs)
        learner.add_negative_examples(*neg_exs)
        try:
            found, ret = synth_dtree(learner, teacher,
                                     num_max_iterations=NUM_MAX_ITER)
            print(f"Found? {found}")
            if found:
                k, expr = ret
                result.append({"part": part,
                               "status": "found",
                               "result": {"k": k, "formula": expr}})
            else:
                result.append({"part": part,
                               "status": "not found",
                               "result": ret})
        except Exception as e:
            # Record the failure for this cell and keep going with the others.
            result.append({"part": part,
                           "status": "exception",
                           "result": traceback.format_exc()})
            print(e)
        finally:
            # Drop the per-cell solver objects before building the next pair.
            del teacher
            del learner

    with open(f"out/dtree_synth.{NUM_Y_PARTS}x{NUM_YAW_PARTS}.out.json", "w") as f:
        json.dump(result, f)
if __name__ == "__main__":
    # Per-partition experiment.  For the single filtered-data experiment,
    # call test_synth_dtree() here instead.
    main()