Newer
Older
#!/usr/bin/env python3
import glob
import itertools
import json
import os
import pickle
import shutil
from typing import Dict, Hashable, Literal

import numpy as np

from dtree_synth import DataSet, search_part, synth_dtree_per_part
from dtree_learner import DTreeLearner
from gem_stanley_teacher import DTreeGEMStanleyGurobiStabilityTeacher
from teacher_base import TeacherBase
def load_partitioned_examples(
        file_name: str,
        teacher: DTreeGEMStanleyGurobiStabilityTeacher,
        partition,
) -> Dict[Hashable, DataSet]:
    """Load pickled ground-truth samples and bucket them by partition cell.

    ``partition`` holds one sorted array of boundaries per state dimension.
    Every cell is a tuple with one ``(lower, upper)`` pair per dimension, built
    from consecutive boundaries, and it maps to an initially empty ``DataSet``.

    Each sample is routed as follows:

    * Its first ``teacher.state_dim`` entries are located with ``search_part``.
      Samples that fall outside every cell are counted and then skipped.
    * A sample containing any NaN only increments ``num_nan_dps``.
    * Otherwise, ``teacher.is_safe_state`` decides whether it is appended to
      ``safe_dps`` or to ``unsafe_dps``.

    Args:
        file_name: Path to a pickle file with a ``"truth_samples"`` entry.
            The entry is a sequence of pairs whose second item is an iterable
            of samples.
        teacher: Supplies ``state_dim`` and the ``is_safe_state`` predicate.
        partition: Per-dimension boundary arrays.

    Returns:
        Mapping from each partition cell to its collected ``DataSet``.
    """
    print("Loading examples")
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(file_name, "rb") as pkl_io:
        loaded = pickle.load(pkl_io)
    truth_samples_seq = loaded["truth_samples"]

    # Consecutive boundary pairs per dimension; their cartesian product enumerates every cell.
    intervals_per_dim = [list(zip(bounds[:-1], bounds[1:])) for bounds in partition]
    part_to_data = {cell: DataSet() for cell in itertools.product(*intervals_per_dim)}

    excluded = 0
    # NOTE(review): the first item of each "truth_samples" entry is ignored here;
    # confirm whether it should feed the safety label instead of is_safe_state.
    for _, samples in truth_samples_seq:
        for sample in samples:
            cell = search_part(partition, sample[0:teacher.state_dim])
            if cell is None:
                excluded += 1
                continue
            if np.any(np.isnan(sample)):
                part_to_data[cell].num_nan_dps += 1
            elif teacher.is_safe_state(sample):
                part_to_data[cell].safe_dps.append(sample)
            else:
                part_to_data[cell].unsafe_dps.append(sample)
    print("# samples not in any selected parts:", excluded)
    return part_to_data
def main(dom: Literal["concat", "diff"], bnd_relax: float):
    """Run per-partition decision-tree synthesis for the GEM Stanley experiment.

    The function performs these steps:

    1. Load the pickled ground-truth samples from ``PKL_FILE_PATH``.
    2. Bucket them into a uniform (x, y, yaw) grid.
    3. Print every part that contains unsafe samples.
    4. Run ``synth_dtree_per_part`` over all parts.
    5. Write the result as JSON to
       ``out/dtree_synth.<NUM_Y_PARTS>x<NUM_YAW_PARTS>.out.json``.

    Args:
        dom: Feature domain used for the learner grammar and for
            ``synth_dtree_per_part``.
        bnd_relax: Ultimate bound passed to the teacher and to
            ``synth_dtree_per_part``.
    """
    X_LIM = np.inf
    X_ARR = np.array([-X_LIM, X_LIM])
    Y_LIM = 1.2
    NUM_Y_PARTS = 4
    Y_ARR = np.linspace(-Y_LIM, Y_LIM, NUM_Y_PARTS + 1)
    YAW_LIM = np.pi / 12
    NUM_YAW_PARTS = 10
    YAW_ARR = np.linspace(-YAW_LIM, YAW_LIM, NUM_YAW_PARTS + 1)
    PARTITION = (X_ARR, Y_ARR, YAW_ARR)

    PKL_FILE_PATH = "data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle"
    # NOTE(review): NORM_ORD was referenced but not defined in the visible source;
    # 1 is an assumed value -- confirm against the intended experiment setup.
    NORM_ORD = 1
    NUM_MAX_ITER = 500
    FEATURE_DOMAIN = dom
    ULT_BOUND = bnd_relax

    def teacher_builder():
        # Factory: synth_dtree_per_part receives this and may build a fresh teacher.
        return DTreeGEMStanleyGurobiStabilityTeacher(norm_ord=NORM_ORD, ultimate_bound=ULT_BOUND)

    def learner_builder(teacher: TeacherBase):
        learner = DTreeLearner(state_dim=teacher.state_dim,
                               perc_dim=teacher.perc_dim, timeout=20000)
        learner.set_grammar([(DTreeGEMStanleyGurobiStabilityTeacher.PERC_GT, np.zeros(2))], FEATURE_DOMAIN)
        return learner

    teacher = teacher_builder()
    part_to_examples = load_partitioned_examples(
        file_name=PKL_FILE_PATH,
        teacher=teacher, partition=PARTITION
    )

    # Print statistics about training data points
    print("#"*80)
    print("Parts with unsafe data points:")
    for i, (part, dataset) in enumerate(part_to_examples.items()):
        safe_dps, unsafe_dps, num_nan = dataset.safe_dps, dataset.unsafe_dps, dataset.num_nan_dps
        # np.asfarray was removed in NumPy 2.0; asarray(dtype=float) is the equivalent.
        lb, ub = np.asarray(part, dtype=float).T
        # Index 2 is the yaw dimension (see PARTITION); show it in degrees.
        lb[2] = np.rad2deg(lb[2])
        ub[2] = np.rad2deg(ub[2])
        if len(unsafe_dps) > 0:
            print(f"Part Index {i}:", f"y in [{lb[1]:.03}, {ub[1]:.03}] (m);", f"θ in [{lb[2]:.03}, {ub[2]:.03}] (deg);",
                  f"# safe: {len(safe_dps)}", f"# unsafe: {len(unsafe_dps):03}", f"# NaN: {num_nan}")

    result = synth_dtree_per_part(
        part_to_examples,
        teacher_builder,
        learner_builder,
        num_max_iter=NUM_MAX_ITER,
        ult_bound=ULT_BOUND,
        feature_domain=FEATURE_DOMAIN
    )
    # Make sure the output directory exists before writing the result file.
    os.makedirs("out", exist_ok=True)
    with open(f"out/dtree_synth.{NUM_Y_PARTS}x{NUM_YAW_PARTS}.out.json", "w") as f:
        json.dump(result, f)
def clean_out_dir(dirname: str) -> None:
    """Archive ``out/`` under *dirname*, then delete the generated outputs.

    The whole ``out/`` tree is copied to ``<dirname>/out``. This is the layout
    the former ``mkdir <dirname>; cp -r out/ <dirname>`` commands produced.
    Files in ``out/`` matching ``*.data``, ``*.json`` or ``pre.*`` are then
    removed so the next experiment starts from a clean output directory.

    Args:
        dirname: Archive directory. It is created if missing and reused if it
            already exists.

    Raises:
        FileNotFoundError: If ``out/`` does not exist.
    """
    os.makedirs(dirname, exist_ok=True)
    shutil.copytree("out", os.path.join(dirname, "out"), dirs_exist_ok=True)
    # glob.glob skips dotfiles, like the shell globs it replaces.
    for pattern in ("*.data", "*.json", "pre.*"):
        for path in glob.glob(os.path.join("out", pattern)):
            if os.path.isfile(path):
                os.remove(path)
if __name__ == "__main__":
    # Each experiment runs the synthesis, then archives out/ into its own
    # directory. Uncomment the experiments to run.
    # main("diff", 0.0)  # EXPERIMENT: 2D, strict
    # clean_out_dir("diff0_0")
    # main("diff", 1.0)  # EXPERIMENT: 2D, relaxed
    # clean_out_dir("diff1_0")
    # main("concat", 0.0)  # EXPERIMENT: 4D, strict
    # clean_out_dir("concat0_0")
    main("concat", 1.0)  # EXPERIMENT: 4D, relaxed
    clean_out_dir("concat1_0")