#!/usr/bin/env python3
import itertools
import json
import pathlib
import pickle
import traceback
from typing import Dict, Hashable, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import z3

from dtree_learner import DTreeLearner as Learner
from dtree_teacher_gem_stanley import DTreeGEMStanleyGurobiTeacher as Teacher
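
# Counterexample-guided synthesis loop: the decision-tree learner (DTreeLearner)
# proposes candidate predicates over vehicle states and percepts, and the
# Gurobi-backed teacher (DTreeGEMStanleyGurobiTeacher) searches for states that
# violate a candidate, feeding them back as negative examples until one passes.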


def load_examples(file_name: str, spec) -> Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]]]:
    print("Loading examples")
    with open(file_name, "rb") as pickle_file_io:
        pkl_data = pickle.load(pickle_file_io)
    truth_samples_seq = pkl_data["truth_samples"]

    # Convert from sampled states and percepts to positive and negative examples for learning
    pos_exs, neg_exs, num_excl_exs = [], [], 0
    for _, ss in truth_samples_seq:
        for s in ss:
            if np.any(np.isnan(s)):
                num_excl_exs += 1  # Exclude samples containing NaN entries
            elif spec(s):  # NOTE assuming `spec` is a predicate deciding whether a sample is positive
                pos_exs.append(s)
            else:
                neg_exs.append(s)
    print("# excluded examples:", num_excl_exs)
    return pos_exs, neg_exs


def synth_dtree(learner: Learner, teacher: Teacher, num_max_iterations: int = 10):
    past_candidate_list = []
    for k in range(num_max_iterations):
        print(f"Iteration {k}:")
        print("learning ....")
        # NOTE assuming the learner exposes learn() to produce the next candidate predicate
        candidate = learner.learn()
        print(f"candidate: {candidate}")
        past_candidate_list.append(candidate)
        # QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES
        result = teacher.check(candidate)
        if result == z3.sat:
            negative_examples = teacher.model()
            print(f"negative examples: {negative_examples}")
            # assert validate_cexs(teacher.state_dim, teacher.perc_dim, candidate, negative_examples)
            learner.add_negative_examples(*negative_examples)
            continue
        elif result == z3.unsat:
            return True, (k, z3.simplify(candidate, arith_lhs=True).sexpr())
        else:
            return False, f"Reason Unknown {teacher.reason_unknown()}"
    return False, f"Reached max iteration {num_max_iterations}."
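
# synth_dtree returns (True, (iteration, formula_sexpr)) when the teacher proves the
# candidate valid (unsat), and (False, reason) when the solver result is unknown or
# the iteration limit is reached; main() below records either outcome per partition.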


def validate_cexs(state_dim: int, perc_dim: int, candidate,
                  cexs: List[Tuple[float]]) -> bool:
    spurious_cexs = [cex for cex in cexs
                     if Teacher.is_spurious_example(state_dim, perc_dim, candidate, cex)]
    if not spurious_cexs:
        return True
    else:
        print("Spurious CEXs:", *spurious_cexs, sep='\n')
        return False
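
# validate_cexs backs the commented-out assertion in synth_dtree: it returns False
# and prints any counterexamples that Teacher.is_spurious_example flags as spurious.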


def search_part(partition, state):
    assert len(partition) == len(state)
    bounds = []
    for sorted_list, v in zip(partition, state):
        i = np.searchsorted(sorted_list, v)
        if i == 0 or i == len(sorted_list):
            return None
        bounds.append((sorted_list[i-1], sorted_list[i]))
    return tuple(bounds)
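
# Example (hypothetical values): for partition ([-np.inf, np.inf], [-1.2, 0.0, 1.2])
# and state (0.3, -0.5), search_part returns ((-inf, inf), (-1.2, 0.0)); a state
# component outside its partitioned range (e.g. y = 2.0) makes it return None.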


def load_partitioned_examples(file_name: str, teacher: Teacher, partition) \
        -> Dict[Hashable, Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]], int]]:
    print("Loading examples")
    with open(file_name, "rb") as pickle_file_io:
        pkl_data = pickle.load(pickle_file_io)
    truth_samples_seq = pkl_data["truth_samples"]

    bound_list = list(list(zip(x_arr[:-1], x_arr[1:])) for x_arr in partition)
    ret = {part: [[], [], 0] for part in itertools.product(*bound_list)}

    # Convert from sampled states and percepts to positive and negative examples for learning
    num_excl_samples = 0
    for _, ss in truth_samples_seq:
        for s in ss:
            state = s[0:teacher.state_dim]
            part = search_part(partition, state)
            if part is None:
                num_excl_samples += 1
                continue
            if np.any(np.isnan(s)):
                ret[part][2] += 1
            elif teacher.is_positive_example(s):
                ret[part][0].append(s)
            else:
                ret[part][1].append(s)
    print("# samples not in any selected parts:", num_excl_samples)
    return ret
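
# Each partition cell maps to [positive examples, negative examples, NaN count];
# samples whose state does not fall into any cell are only counted and dropped.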


def main():
    X_LIM = np.inf
    X_ARR = np.array([-X_LIM, X_LIM])
    Y_LIM = 1.2
    NUM_Y_PARTS = 4
    Y_ARR = np.linspace(-Y_LIM, Y_LIM, NUM_Y_PARTS + 1)
    YAW_LIM = np.pi / 12
    NUM_YAW_PARTS = 10
    YAW_ARR = np.linspace(-YAW_LIM, YAW_LIM, NUM_YAW_PARTS + 1)
    PARTITION = (X_ARR, Y_ARR, YAW_ARR)
    PKL_FILE_PATH = "data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle"
    NORM_ORD = 1
    NUM_MAX_ITER = 500
    FEATURE_DOMAIN = "diff"
    ULT_BOUND = 0.0

    teacher = Teacher(norm_ord=NORM_ORD, ultimate_bound=ULT_BOUND)
    part_to_examples = load_partitioned_examples(
        file_name=PKL_FILE_PATH,
        teacher=teacher, partition=PARTITION
    )

    # Print statistics about training data points
    print("#" * 80)
    print("Parts with unsafe data points:")
    for i, (part, (safe_dps, unsafe_dps, num_nan)) in enumerate(part_to_examples.items()):
        lb, ub = np.asfarray(part).T
        lb[2] = np.rad2deg(lb[2])
        ub[2] = np.rad2deg(ub[2])
        if len(unsafe_dps) > 0:
            print(f"Part Index {i}:", f"y in [{lb[1]:.03}, {ub[1]:.03}] (m);", f"θ in [{lb[2]:.03}, {ub[2]:.03}] (deg);",
                  f"# safe: {len(safe_dps)}", f"# unsafe: {len(unsafe_dps):03}", f"# NaN: {num_nan}")

    result = []  # Collect one synthesis outcome per partition
    for i, (part, (safe_dps, unsafe_dps, num_nan)) in enumerate(part_to_examples.items()):
        print(f"# safe: {len(safe_dps)}; "
              f"# unsafe: {len(unsafe_dps)}; "
              f"# NaN: {num_nan}")
        lb, ub = np.asfarray(part).T
        # NOTE We create new teacher and learner for each part to avoid solutions from other parts
        teacher = Teacher(norm_ord=NORM_ORD, ultimate_bound=ULT_BOUND)
        teacher.set_old_state_bound(lb=lb, ub=ub)
        learner = Learner(state_dim=teacher.state_dim,
                          perc_dim=teacher.perc_dim, timeout=20000)
        learner.set_grammar([(Teacher.PERC_GT, np.zeros(2))], FEATURE_DOMAIN)
        learner.add_positive_examples(*safe_dps)
        learner.add_negative_examples(*unsafe_dps)
        try:
            found, ret = synth_dtree(learner, teacher,
                                     num_max_iterations=NUM_MAX_ITER)
            print(f"Found? {found}")
            if found:
                k, expr = ret
                result.append({"part": part,
                               "feature_domain": FEATURE_DOMAIN,
                               "ultimate_bound": ULT_BOUND,
                               "status": "found",
                               "result": {"k": k, "formula": expr}})
            else:
                result.append({"part": part,
                               "feature_domain": FEATURE_DOMAIN,
                               "ultimate_bound": ULT_BOUND,
                               "status": "not found",
                               "result": ret})
        except KeyboardInterrupt:
            print("User pressed Ctrl+C. Skipping all remaining partitions.")
            break  # NOTE finally block is executed before break
        except Exception as e:
            result.append({"part": part,
                           "status": "exception",
                           "result": traceback.format_exc()})
            print(e)
        finally:
            data_file = pathlib.Path("out/pre.data")
            data_file.rename(f"out/part-{i:03}-pre.data")

    with open(f"out/dtree_synth.{NUM_Y_PARTS}x{NUM_YAW_PARTS}.out.json", "w") as f:
        json.dump(result, f)
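

# NOTE assuming the script is meant to be run directly as the entry point.
if __name__ == "__main__":
    main()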