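"""dtree_synth.py: learner/teacher loop for synthesizing a decision-tree DNF.

Positive examples are loaded from a pickle file, a ``DTreeLearner`` proposes
candidate DNFs, and a ``DTreeTeacherGEMStanley`` checks each candidate and
supplies negative examples until no counterexample remains.
"""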
import itertools
import pickle
from typing import List, Tuple

import numpy as np
import z3
import csv
from dtree_learner import DTreeLearner as Learner
from dtree_teacher_gem_stanley import DTreeTeacherGEMStanley as Teacher


def load_positive_examples(file_name: str,
                           i_th: int = 0) -> List[Tuple[float, ...]]:
    """Load the NaN-free samples of one partition from a pickle file.

    The pickle must hold a mapping with a ``"truth_samples"`` entry: a
    sequence of ``(representative, samples)`` pairs, one pair per partition.

    Args:
        file_name: Path of the pickle file to read.
        i_th: Index of the partition to select. Defaults to 0, which is the
            partition that was previously hard-coded. Negative values index
            from the end, as with any Python sequence.

    Returns:
        The samples of the selected partition, in their stored order, without
        the samples that contain at least one NaN component.

    Raises:
        IndexError: If ``i_th`` is out of range (e.g. no partitions stored).
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point this at trusted data files.
    with open(file_name, "rb") as pickle_file_io:
        pkl_data = pickle.load(pickle_file_io)

    # Select only the i-th partition.
    representative, raw_samples = pkl_data["truth_samples"][i_th]
    print("Representative point in partition:", representative)

    # Sampled states and percepts become positive examples for learning;
    # samples with any NaN component are dropped.
    return [s for s in raw_samples if not any(np.isnan(s))]


def test_synth_dtree():
    """Run the synthesis loop end to end on the recorded sample file.

    Loads the positive examples, checks that every example has the same
    dimension and no NaN component, configures the teacher's old-state
    bounds, and calls ``synth_dtree`` with at most 50 iterations.
    """
    positive_examples = load_positive_examples(
        "data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle")

    # To debug on a small subset, slice here, e.g. positive_examples[:20].
    ex_dim = len(positive_examples[0])
    print(f"#examples: {len(positive_examples)}")
    print(f"Dimension of each example: {ex_dim}")

    assert all(len(ex) == ex_dim and not any(np.isnan(ex))
               for ex in positive_examples)

    teacher = Teacher()
    # Each example must be a state vector followed by a percept vector.
    assert teacher.state_dim + teacher.perc_dim == ex_dim
    # 0.0 <= x <= 30.0 and -1.0 <= y <= -0.9 and 0.2 <= theta <= 0.22
    # NOTE(review): an earlier comment gave the upper bound of y as 0.9
    # while the code passes -0.9 -- confirm which sign is intended.
    teacher.set_old_state_bound(lb=[0.0, -1.0, 0.2], ub=[30.0, -0.9, 0.22])

    synth_dtree(positive_examples, teacher, num_max_iterations=50)


def _write_labeled_csv(path, examples, label):
    """Write one CSV row per example with ``label`` appended as last column."""
    # The csv module expects files it writes to be opened with newline='';
    # otherwise platforms that translate newlines emit blank rows.
    with open(path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        for example in examples:
            writer.writerow(itertools.chain(example, [label]))


def synth_dtree(positive_examples, teacher, num_max_iterations: int = 10):
    """Iterate learner and teacher until a candidate has no counterexample.

    Each iteration asks the learner for a candidate DNF, has the teacher
    check every conjunct of it, and feeds the counterexamples found back to
    the learner as negative examples.

    Side effects: writes ``positive_sample.csv`` once, rewrites
    ``negative_sample.csv`` on every completed iteration, and pickles the
    accepted candidate to ``winnerDnf`` on success.

    Args:
        positive_examples: Examples passed to the learner as positives.
        teacher: Object providing ``state_dim``, ``perc_dim``, ``check``,
            ``model`` and ``reason_unknown``.
        num_max_iterations: Upper bound on learner/teacher rounds.

    Returns:
        The list of all candidates tried so far when a candidate is accepted
        or when the teacher answers neither sat nor unsat; ``None`` when the
        learner fails to produce a candidate; an empty list when the
        iteration limit is reached.
    """
    learner = Learner(state_dim=teacher.state_dim,
                      perc_dim=teacher.perc_dim, timeout=20000)

    a_mat_0 = np.array([[0., -1., 0.],
                        [0., 0., -1.]])
    b_vec_0 = np.zeros(2)
    learner.set_grammar([(a_mat_0, b_vec_0)])

    # Dump the positive examples to a CSV file.
    _write_labeled_csv("positive_sample.csv", positive_examples, "true")

    learner.add_positive_examples(*positive_examples)

    past_candidate_list = []
    for k in range(num_max_iterations):
        print(f"Iteration {k}:")
        print("learning ....")

        candidate_dnf = learner.learn()
        print("done learning")

        if not candidate_dnf:  # learning FAILED
            print("Learning Failed.")
            return

        print(f"candidate DNF: {candidate_dnf}")

        past_candidate_list.append(candidate_dnf)
        # Ask the teacher whether any conjunct admits negative examples.
        print(f"number of paths: {len(candidate_dnf)}")
        negative_examples = []
        for candidate in candidate_dnf:
            result = teacher.check(candidate)
            print(result)
            if result == z3.sat:
                m = teacher.model()
                assert len(m) > 0
                negative_examples.extend(m)
                # TODO check if negative example state is spurious or a
                # true counterexample
            elif result == z3.unsat:
                continue
            else:
                print("Reason Unknown", teacher.reason_unknown())
                return past_candidate_list

        # NOTE(review): negative examples are labeled "true" exactly like
        # the positive ones; kept as-is, but confirm this is not a
        # copy-paste slip for "false".
        _write_labeled_csv("negative_sample.csv", negative_examples, "true")

        print(f"negative examples: {negative_examples}")
        if negative_examples:
            learner.add_negative_examples(*negative_examples)
        else:
            # No counterexample for any conjunct: persist the winner.
            with open('winnerDnf', 'wb') as winner_file:
                pickle.dump(candidate_dnf, winner_file)
            print("we are done!")
            return past_candidate_list

    print(f"Reached max iteration {num_max_iterations}.")
    return []


# Script entry point: run the synthesis experiment when executed directly.
if __name__ == "__main__":
    test_synth_dtree()