dtree_synth.py 8.14 KiB
#!/usr/bin/env python3
import itertools
import json
import matplotlib.pyplot as plt
import pathlib
import pickle
import traceback
from typing import Dict, Hashable, List, Tuple
import numpy as np
import z3
from dtree_learner import DTreeLearner as Learner
from gem_stanley_teacher import DTreeGEMStanleyGurobiTeacher as Teacher
def load_examples(file_name: str, spec) -> Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]]]:
    """Read pickled truth samples and split them into positive/negative examples.

    The pickle stored at ``file_name`` must be a mapping with a
    ``"truth_samples"`` entry: an iterable of 2-item entries whose second item
    is an iterable of samples (the first item is ignored here).

    Samples containing any NaN are dropped and only counted; every other
    sample is labelled by calling ``spec`` on it.

    Args:
        file_name: path of the pickle file to read.
        spec: callable applied to each NaN-free sample; a truthy result marks
            the sample as positive, a falsy result as negative.

    Returns:
        ``(positives, negatives)``, each a list of samples in file order.
    """
    print("Loading examples")
    with open(file_name, "rb") as pickle_file_io:
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        truth_samples_seq = pickle.load(pickle_file_io)["truth_samples"]

    positives, negatives = [], []
    nan_count = 0
    for _, samples in truth_samples_seq:
        for sample in samples:
            if np.any(np.isnan(sample)):
                nan_count += 1
                continue
            # Route the sample to whichever list the spec selects.
            target = positives if spec(sample) else negatives
            target.append(sample)

    print("# Excluded NaN examples:", nan_count)
    return positives, negatives
def synth_dtree(learner: Learner, teacher: Teacher, num_max_iterations: int = 10):
    """Run the learner/teacher refinement loop for a bounded number of rounds.

    In each round the learner proposes a candidate via ``learner.learn()`` and
    the teacher checks it via ``teacher.check(candidate)``:

    * ``z3.sat``   -- the teacher's model supplies negative examples, which
      are fed back to the learner before the next round;
    * ``z3.unsat`` -- the candidate is accepted and the loop ends;
    * anything else -- the loop ends with the teacher's ``reason_unknown()``.

    Args:
        learner: object providing ``learn()`` and ``add_negative_examples()``.
        teacher: object providing ``check()``, ``model()`` and
            ``reason_unknown()``.
        num_max_iterations: upper bound on the number of rounds.

    Returns:
        ``(True, (k, sexpr))`` when round ``k`` was accepted, where ``sexpr``
        is the s-expression string of the simplified candidate; otherwise
        ``(False, message)`` with a human-readable reason.
    """
    # Every candidate proposed so far; not read anywhere in this function.
    proposed = []
    for k in range(num_max_iterations):
        print("=" * 80)
        print(f"Iteration {k}:", sep='')
        print("learning ....")
        candidate = learner.learn()
        print("done learning")
        print(f"candidate: {candidate}")
        proposed.append(candidate)

        # Ask the teacher whether the candidate still admits negative examples.
        verdict = teacher.check(candidate)
        print(f"Satisfiability: {verdict}")

        if verdict == z3.sat:
            counterexamples = teacher.model()
            assert len(counterexamples) > 0
            print(f"negative examples: {counterexamples}")
            # Optional sanity check, currently disabled:
            # assert validate_cexs(teacher.state_dim, teacher.perc_dim, candidate, counterexamples)
            learner.add_negative_examples(*counterexamples)
            continue

        if verdict == z3.unsat:
            print("we are done!")
            return True, (k, z3.simplify(candidate, arith_lhs=True).sexpr())

        # Neither sat nor unsat: give up and report why.
        return False, f"Reason Unknown {teacher.reason_unknown()}"

    return False, f"Reached max iteration {num_max_iterations}."
def validate_cexs(state_dim: int, perc_dim: int,
                  candidate: z3.BoolRef,
                  cexs: List[Tuple[float]]) -> bool:
    """Check that no counterexample is flagged as spurious by the teacher.

    Each element of ``cexs`` is passed, together with ``state_dim``,
    ``perc_dim`` and ``candidate``, to ``Teacher.is_spurious_example``.

    Returns:
        True when nothing was flagged. Otherwise the flagged counterexamples
        are printed one per line and False is returned.
    """
    flagged = []
    for cex in cexs:
        if Teacher.is_spurious_example(state_dim, perc_dim, candidate, cex):
            flagged.append(cex)

    if flagged:
        print("Spurious CEXs:", *flagged, sep='\n')
        return False
    return True
def search_part(partition, state):
    """Locate the grid cell of ``partition`` that contains ``state``.

    ``partition`` holds one sorted sequence of edges per state dimension and
    ``state`` one value per dimension. For every dimension the value is looked
    up with ``np.searchsorted`` (default left side), so a value equal to an
    edge belongs to the interval that ends at that edge, i.e. intervals are
    ``(lower, upper]``.

    Returns:
        A tuple of ``(lower, upper)`` edge pairs, one per dimension, or None
        as soon as some value lies at or below the first edge or above the
        last edge of its dimension.
    """
    assert len(partition) == len(state)
    cell = []
    for edges, value in zip(partition, state):
        idx = np.searchsorted(edges, value)
        # idx == 0 or idx == len(edges) means the value is outside the edges.
        if not 0 < idx < len(edges):
            return None
        cell.append((edges[idx - 1], edges[idx]))
    return tuple(cell)
def load_partitioned_examples(file_name: str, teacher: Teacher, partition) \
        -> Dict[Hashable, Tuple[List[Tuple[float, ...]], List[Tuple[float, ...]], int]]:
    """Read pickled truth samples and bucket them by grid cell.

    The pickle stored at ``file_name`` must be a mapping with a
    ``"truth_samples"`` entry: an iterable of 2-item entries whose second item
    is an iterable of samples. The first ``teacher.state_dim`` components of a
    sample are used to find its cell with ``search_part``.

    Args:
        file_name: path of the pickle file to read.
        teacher: supplies ``state_dim`` and ``is_positive_example``.
        partition: one sorted sequence of edges per state dimension.

    Returns:
        A dict keyed by cell (a tuple of ``(lower, upper)`` pairs, one per
        dimension). Each value is a 3-item list
        ``[positives, negatives, nan_count]``. Samples that fall in no cell
        are skipped and only counted in the printed summary.
    """
    print("Loading examples")
    with open(file_name, "rb") as pickle_file_io:
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        truth_samples_seq = pickle.load(pickle_file_io)["truth_samples"]

    # Per dimension, pair consecutive edges into (lower, upper) intervals.
    intervals_per_dim = [list(zip(edges[:-1], edges[1:])) for edges in partition]
    # One [positives, negatives, nan_count] record for every cell of the grid.
    ret = {cell: [[], [], 0] for cell in itertools.product(*intervals_per_dim)}

    # Convert sampled states and percepts into per-cell learning examples.
    outside_count = 0
    for _, samples in truth_samples_seq:
        for sample in samples:
            cell = search_part(partition, sample[0:teacher.state_dim])
            if cell is None:
                outside_count += 1
                continue
            if np.any(np.isnan(sample)):
                ret[cell][2] += 1
                continue
            # TODO: the first item of each entry (`_`) is never passed to
            # is_positive_example. Isn't it the ground truth for each sample
            # in `samples`?
            slot = 0 if teacher.is_positive_example(sample) else 1
            ret[cell][slot].append(sample)

    print("# samples not in any selected parts:", outside_count)
    return ret
def main(dim: str, bnd_relax: float):
    """Synthesize one formula per state-space part and save all outcomes as JSON.

    The (x, y, yaw) space is cut into a grid (x unbounded, y and yaw evenly
    spaced). Samples from the pickle file are bucketed per part, and for each
    part a fresh Teacher and Learner are created, seeded with that part's
    examples and run through ``synth_dtree``. One record per processed part is
    collected and finally written to
    ``out/dtree_synth.<NUM_Y_PARTS>x<NUM_YAW_PARTS>.out.json``.

    Args:
        dim: feature domain handed to ``Learner.set_grammar`` (the call sites
            in this file use ``"diff"`` and ``"concat"``).
        bnd_relax: value passed to the Teacher as ``ultimate_bound``.
    """
    X_LIM = np.inf
    X_ARR = np.array([-X_LIM, X_LIM])

    Y_LIM = 1.2
    NUM_Y_PARTS = 4
    Y_ARR = np.linspace(-Y_LIM, Y_LIM, NUM_Y_PARTS + 1)

    YAW_LIM = np.pi / 12
    NUM_YAW_PARTS = 10
    YAW_ARR = np.linspace(-YAW_LIM, YAW_LIM, NUM_YAW_PARTS + 1)

    PARTITION = (X_ARR, Y_ARR, YAW_ARR)

    PKL_FILE_PATH = "data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle"
    NORM_ORD = 1
    NUM_MAX_ITER = 500
    # FEATURE_DOMAIN = "diff"  # concat  # diff is 2D and concat is 4D
    FEATURE_DOMAIN = dim
    ULT_BOUND = bnd_relax

    # Make sure the output directory exists so that the final JSON dump cannot
    # fail after a long synthesis run.
    pathlib.Path("out").mkdir(parents=True, exist_ok=True)

    teacher = Teacher(norm_ord=NORM_ORD, ultimate_bound=ULT_BOUND)
    part_to_examples = load_partitioned_examples(
        file_name=PKL_FILE_PATH,
        teacher=teacher, partition=PARTITION
    )

    # Print statistics about training data points
    print("#"*80)
    print("Parts with unsafe data points:")
    for i, (part, (safe_dps, unsafe_dps, num_nan)) in enumerate(part_to_examples.items()):
        # np.asfarray was removed in NumPy 2.0; asarray(dtype=float) is the
        # equivalent call and also works on older releases.
        lb, ub = np.asarray(part, dtype=float).T
        # Yaw bounds are shown in degrees; lb/ub are fresh arrays, so the
        # dictionary keys are not affected by this in-place conversion.
        lb[2] = np.rad2deg(lb[2])
        ub[2] = np.rad2deg(ub[2])
        if len(unsafe_dps) > 0:
            print(f"Part Index {i}:", f"y in [{lb[1]:.03}, {ub[1]:.03}] (m);", f"θ in [{lb[2]:.03}, {ub[2]:.03}] (deg);",
                  f"# safe: {len(safe_dps)}", f"# unsafe: {len(unsafe_dps):03}", f"# NaN: {num_nan}")

    result = []
    for i, (part, (safe_dps, unsafe_dps, num_nan)) in enumerate(part_to_examples.items()):
        # To debug a single part, skip the others here, e.g.:
        # if not i == 1:
        #     continue
        print("#"*80)
        print(f"# safe: {len(safe_dps)}; "
              f"# unsafe: {len(unsafe_dps)}; "
              f"# NaN: {num_nan}")
        lb, ub = np.asarray(part, dtype=float).T

        # NOTE We create new teacher and learner for each part to avoid solutions from other parts
        # TODO incremental solving
        teacher = Teacher(norm_ord=NORM_ORD, ultimate_bound=ULT_BOUND)
        teacher.set_old_state_bound(lb=lb, ub=ub)

        learner = Learner(state_dim=teacher.state_dim,
                          perc_dim=teacher.perc_dim, timeout=20000)
        learner.set_grammar([(Teacher.PERC_GT, np.zeros(2))], FEATURE_DOMAIN)

        learner.add_positive_examples(*safe_dps)
        learner.add_negative_examples(*unsafe_dps)
        try:
            found, ret = synth_dtree(learner, teacher,
                                     num_max_iterations=NUM_MAX_ITER)
            print(f"Found? {found}")
            if found:
                k, expr = ret
                result.append({"part": part,
                               "feature_domain": FEATURE_DOMAIN,
                               "ultimate_bound": ULT_BOUND,
                               "status": "found",
                               "result": {"k": k, "formula": expr}})
            else:
                result.append({"part": part,
                               "feature_domain": FEATURE_DOMAIN,
                               "ultimate_bound": ULT_BOUND,
                               "status": "not found",
                               "result": ret})
        except KeyboardInterrupt:
            print("User pressed Ctrl+C. Skip all remaining partition.")
            break  # NOTE finally block is executed before break
        except Exception as e:
            # Record the failure with the same metadata keys as the other
            # outcomes so that every record has a uniform shape.
            result.append({"part": part,
                           "feature_domain": FEATURE_DOMAIN,
                           "ultimate_bound": ULT_BOUND,
                           "status": "exception",
                           "result": traceback.format_exc()})
            print(e)
        finally:
            # NOTE(review): out/pre.data is presumably written while the
            # teacher runs -- confirm. It may be absent when synthesis failed
            # early; an unguarded rename would then raise from this finally
            # block, abort the loop and lose the results gathered so far.
            data_file = pathlib.Path("out/pre.data")
            if data_file.exists():
                data_file.rename(f"out/part-{i:03}-pre.data")
            else:
                print(f"{data_file} not found; nothing to archive for part {i}.")
            del teacher
            del learner

    with open(f"out/dtree_synth.{NUM_Y_PARTS}x{NUM_YAW_PARTS}.out.json", "w") as f:
        json.dump(result, f)
if __name__ == "__main__":
    # Experiment configurations; enable exactly one call per run.
    # main(dim="diff", bnd_relax=0.0)    # EXPERIMENT: 2D, strict
    # main(dim="diff", bnd_relax=1.0)    # EXPERIMENT: 2D, relaxed
    # main(dim="concat", bnd_relax=0.0)  # EXPERIMENT: 4D, strict
    main(dim="concat", bnd_relax=1.0)  # EXPERIMENT: 4D, relaxed