    # dtree_synth.py
    import pickle
    from typing import List, Tuple
    
    import numpy as np
    import z3
    from dtree_learner import DTreeLearner as Learner
    
    from dtree_teacher_gem_stanley import DTreeGEMStanleyGurobiTeacher as Teacher
    
    def load_examples(file_name: str, spec) -> Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]]]:
        """Load pickled truth samples and split them into positive/negative examples.

        The pickle file is expected to hold a mapping with a ``"truth_samples"``
        entry that iterates as ``(ignored, samples)`` pairs; only the second
        element of each pair is used.

        Args:
            file_name: Path of the pickle file to read.
            spec: Callable applied to every NaN-free sample.  A truthy result
                marks the sample as positive, ``None`` excludes it, and any
                other falsy result marks it as negative.

        Returns:
            A ``(positive_examples, negative_examples)`` pair of lists.  Samples
            containing NaN, or for which ``spec`` returns ``None``, appear in
            neither list; their count is only printed.
        """
        print("Loading examples")

        # NOTE(security): pickle.load can execute arbitrary code while
        # unpickling -- only point this at trusted data files.
        with open(file_name, "rb") as pickle_file_io:
            pkl_data = pickle.load(pickle_file_io)

        truth_samples_seq = pkl_data["truth_samples"]

        # Convert from sampled states and percepts to positive and negative examples for learning
        pos_exs, neg_exs, num_excl_exs = [], [], 0
        for _, ss in truth_samples_seq:
            for s in ss:
                # Filter NaN samples *before* consulting the spec so that the
                # spec is never evaluated on data that is discarded anyway.
                if np.any(np.isnan(s)):
                    num_excl_exs += 1
                    continue

                ret = spec(s)
                if ret is None:
                    num_excl_exs += 1
                elif ret:
                    pos_exs.append(s)
                else:
                    neg_exs.append(s)
        print("# Excluded examples:", num_excl_exs)
        return pos_exs, neg_exs
    
    
    
    def test_synth_dtree():
        """Run decision-tree synthesis on a recorded data set.

        Builds a teacher restricted to one state region, labels the recorded
        samples with ``teacher.is_positive_example``, sanity-checks the
        positive examples, and then starts the synthesis loop with a budget
        of 2000 iterations.
        """
        # 0.0 <= x <= 32.0 and -1.0 <= y <= -0.9 and 0.2 <= theta <= 0.22
        state_lb = [0.0, -1.0, 0.2]
        state_ub = [32.0, -0.9, 0.22]
        # Other regions that were tried:
        #   lb=[0.0, -0.9, 2*np.pi/60], ub=[32.0, -0.6, 3*np.pi/60]
        #   lb=[0.0, 0.3, 1*np.pi/60], ub=[32.0, 0.9, 5*np.pi/60]

        teacher = Teacher(norm_ord=1)
        teacher.set_old_state_bound(lb=state_lb, ub=state_ub)

        # Previously used data set:
        #   "data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle"
        data_file = "data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle"
        positive_examples, negative_examples = load_examples(
            data_file, teacher.is_positive_example)

        num_pos = len(positive_examples)
        num_neg = len(negative_examples)
        print("# positive examples: %d" % num_pos)
        print("# negative examples: %d" % num_neg)

        ex_dim = len(positive_examples[0])
        print("Dimension of each example: %d" % ex_dim)

        # Every positive example must have the same length and be NaN-free.
        for ex in positive_examples:
            assert len(ex) == ex_dim and not any(np.isnan(ex))
        # An example is a state concatenated with a percept.
        assert teacher.state_dim + teacher.perc_dim == ex_dim

        synth_dtree(positive_examples, negative_examples, teacher,
                    num_max_iterations=2000)
    
    def synth_dtree(positive_examples, negative_examples, teacher, num_max_iterations: int = 10):
        """Alternate between learner and teacher until a candidate is accepted.

        Each iteration asks the learner for a candidate and passes it to
        ``teacher.check``.  If the result is ``z3.sat`` the examples returned
        by ``teacher.model()`` are added to the learner as negative examples
        and the loop continues; ``z3.unsat`` ends the search successfully; any
        other result ends it after printing ``teacher.reason_unknown()``.

        Args:
            positive_examples: Initial positive examples given to the learner.
            negative_examples: Initial negative examples given to the learner.
            teacher: Object providing ``state_dim``, ``perc_dim``, ``check``,
                ``model`` and ``reason_unknown``.
            num_max_iterations: Upper bound on learner/teacher rounds.

        Returns:
            The list of all candidates proposed by the learner, in order.  The
            list is returned on every exit path, including when the iteration
            budget is exhausted.
        """
        learner = Learner(state_dim=teacher.state_dim,
                          perc_dim=teacher.perc_dim, timeout=20000)

        # Let z = [z_0, z_1] = [d, psi]; x = [x_0, x_1, x_2] = [x, y, theta]
        # a_mat_0 @ [x, y, theta] + b_vec_0 = [-y, -theta]
        # z - (a_mat_0 @ x + b_vec_0) = [d, psi] - [-y, -theta] = [d+y, psi+theta]
        # defined as [fvar0_A0, fvar1_A0]
        a_mat_0 = np.array([[0., -1., 0.],
                            [0., 0., -1.]])
        b_vec_0 = np.zeros(2)

        learner.set_grammar([(a_mat_0, b_vec_0)])
        learner.add_positive_examples(*positive_examples)
        learner.add_negative_examples(*negative_examples)

        past_candidate_list = []
        for k in range(num_max_iterations):
            print("="*80)
            print(f"Iteration {k}:")
            print("learning ....")

            candidate = learner.learn()

            print("done learning")
            print(f"candidate: {candidate}")

            past_candidate_list.append(candidate)

            # Ask the teacher to check the current candidate.
            result = teacher.check(candidate)
            print(f"Satisfiability: {result}")

            if result == z3.sat:
                # Kept in a separate name so the caller-supplied
                # ``negative_examples`` argument is not shadowed.
                counterexamples = teacher.model()
                print(f"negative examples: {counterexamples}")

                # Optional debugging check:
                # assert validate_cexs(teacher.state_dim, teacher.perc_dim, candidate, counterexamples)

                learner.add_negative_examples(*counterexamples)
                continue

            if result == z3.unsat:
                print("we are done!")
                print(f"Simplified candidate: {z3.simplify(candidate)}")
                return past_candidate_list

            # Neither sat nor unsat: report the reason and stop.
            print("Reason Unknown", teacher.reason_unknown())
            return past_candidate_list

        print("Reached max iteration %d." % num_max_iterations)
        # Return the candidates here as well, so every exit path gives the
        # caller the same type instead of an implicit None.
        return past_candidate_list
    
    
    def validate_cexs(state_dim: int, perc_dim: int,
                      candidate: z3.BoolRef,
                      cexs: List[Tuple[float]]) -> bool:
        """Check that no counterexample is reported as spurious.

        Each counterexample is passed to ``Teacher.is_spurious_example``
        together with the dimensions and the candidate.  Any counterexamples
        flagged by that call are printed, one per line.

        Returns:
            ``True`` when nothing was flagged (including when ``cexs`` is
            empty), ``False`` otherwise.
        """
        flagged = []
        for cex in cexs:
            if Teacher.is_spurious_example(state_dim, perc_dim, candidate, cex):
                flagged.append(cex)

        if flagged:
            print("Spurious CEXs:", *flagged, sep='\n')
            return False
        return True
    
    # Run the synthesis experiment only when this file is executed as a script.
    if __name__ == "__main__":
        test_synth_dtree()