Skip to content
Snippets Groups Projects
Commit b49cb21d authored by aastorg2's avatar aastorg2
Browse files

Merge branch 'main' of gitlab.engr.illinois.edu:chsieh16/cs598mp-fall2021-proj into main

parents 5f71a28d 6497d7ad
No related branches found
No related tags found
No related merge requests found
......@@ -31,46 +31,10 @@ def load_examples(file_name: str, spec) -> Tuple[List[Tuple[float, ...]], List[T
pos_exs.append(s)
else:
neg_exs.append(s)
print("# Exculded NaN examples:", num_excl_exs)
print("# Excluded NaN examples:", num_excl_exs)
return pos_exs, neg_exs
def test_synth_dtree():
""" Test using filtered data where we know an abstraction exists. """
# Initialize Teacher
teacher = Teacher(norm_ord=1, ultimate_bound=0.25)
teacher.set_old_state_bound(lb=[0.0, -1.0, 0.2], ub=[32.0, -0.9, 0.22])
init_positive_examples, init_negative_examples = load_examples(
"data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle",
teacher.is_positive_example)
print("# positive examples: %d" % len(init_positive_examples))
print("# negative examples: %d" % len(init_negative_examples))
ex_dim = len(init_positive_examples[0])
print("Dimension of each example: %d" % ex_dim)
assert all(len(ex) == ex_dim and not any(np.isnan(ex))
for ex in init_positive_examples)
assert teacher.state_dim + teacher.perc_dim == ex_dim
# Initialize Learner
learner = Learner(state_dim=teacher.state_dim,
perc_dim=teacher.perc_dim, timeout=20000)
a_mat_0 = Teacher.PERC_GT
b_vec_0 = np.zeros(2)
# Let z = [z_0, z_1] = [d, psi]; x = [x_0, x_1, x_2] = [x, y, theta]
# a_mat_0 @ [x, y, theta] + b_vec_0 = [-y, -theta]
# z - (a_mat_0 @ x + b_vec_0) = [d, psi] - [-y, -theta] = [d+y, psi+theta] defined as [fvar0_A0, fvar1_A0]
learner.set_grammar([(a_mat_0, b_vec_0)])
learner.add_positive_examples(*init_positive_examples)
learner.add_negative_examples(*init_negative_examples)
synth_dtree(learner, teacher, num_max_iterations=2000)
def synth_dtree(learner: Learner, teacher: Teacher, num_max_iterations: int = 10):
past_candidate_list = []
for k in range(num_max_iterations):
......@@ -153,7 +117,7 @@ def load_partitioned_examples(file_name: str, teacher: Teacher, partition) \
ret[part][0].append(s)
else:
ret[part][1].append(s)
print("# excluded samples:", num_excl_samples)
print("# samples not in any selected parts:", num_excl_samples)
return ret
......@@ -174,32 +138,46 @@ def main():
PKL_FILE_PATH = "data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle"
NORM_ORD = 1
NUM_MAX_ITER = 500
FEATURE_DOMAIN = "diff"
ULT_BOUND = 0.0
teacher = Teacher(norm_ord=NORM_ORD)
teacher = Teacher(norm_ord=NORM_ORD, ultimate_bound=ULT_BOUND)
part_to_examples = load_partitioned_examples(
file_name=PKL_FILE_PATH,
teacher=teacher, partition=PARTITION
)
# Print statistics about training data points
print("#"*80)
print("Parts with unsafe data points:")
for i, (part, (safe_dps, unsafe_dps, num_nan)) in enumerate(part_to_examples.items()):
lb, ub = np.asfarray(part).T
lb[2] = np.rad2deg(lb[2])
ub[2] = np.rad2deg(ub[2])
if len(unsafe_dps) > 0:
print(f"Part Index {i}:", f"y in [{lb[1]:.03}, {ub[1]:.03}] (m);", f"θ in [{lb[2]:.03}, {ub[2]:.03}] (deg);",
f"# safe: {len(safe_dps)}", f"# unsafe: {len(unsafe_dps):03}", f"# NaN: {num_nan}")
result = []
for i, (part, (pos_exs, neg_exs, num_nan)) in enumerate(part_to_examples.items()):
for i, (part, (safe_dps, unsafe_dps, num_nan)) in enumerate(part_to_examples.items()):
print("#"*80)
print(f"# positive: {len(pos_exs)}; "
f"# negative: {len(neg_exs)}; "
print(f"# safe: {len(safe_dps)}; "
f"# unsafe: {len(unsafe_dps)}; "
f"# NaN: {num_nan}")
lb, ub = np.asfarray(part).T
# XXX Create new teacher and learner for each part to avoid solutions from other parts
# NOTE We create new teacher and learner for each part to avoid solutions from other parts
# TODO incremental solving
teacher = Teacher(norm_ord=NORM_ORD)
teacher = Teacher(norm_ord=NORM_ORD, ultimate_bound=ULT_BOUND)
teacher.set_old_state_bound(lb=lb, ub=ub)
learner = Learner(state_dim=teacher.state_dim,
perc_dim=teacher.perc_dim, timeout=20000)
learner.set_grammar([(Teacher.PERC_GT, np.zeros(2))])
learner.set_grammar([(Teacher.PERC_GT, np.zeros(2))], FEATURE_DOMAIN)
learner.add_positive_examples(*pos_exs)
learner.add_negative_examples(*neg_exs)
learner.add_positive_examples(*safe_dps)
learner.add_negative_examples(*unsafe_dps)
try:
found, ret = synth_dtree(learner, teacher,
num_max_iterations=NUM_MAX_ITER)
......@@ -207,12 +185,19 @@ def main():
if found:
k, expr = ret
result.append({"part": part,
"feature_domain": FEATURE_DOMAIN,
"ultimate_bound": ULT_BOUND,
"status": "found",
"result": {"k": k, "formula": expr}})
else:
result.append({"part": part,
"feature_domain": FEATURE_DOMAIN,
"ultimate_bound": ULT_BOUND,
"status": "not found",
"result": ret})
except KeyboardInterrupt:
print("User pressed Ctrl+C. Skip all remaining partition.")
break # NOTE finally block is executed before break
except Exception as e:
result.append({"part": part,
"status": "exception",
......@@ -229,5 +214,4 @@ def main():
if __name__ == "__main__":
# test_synth_dtree()
main()
......@@ -139,11 +139,12 @@ class DTreeGEMStanleyGurobiTeacher(GEMStanleyGurobiTeacher):
if filtered_cex_list:
self._cexs.extend(filtered_cex_list)
else:
# raise RuntimeError(f"Only found spurious cexs {cex_list} for the conjuct {conjunct}.")
# raise RuntimeError(f"Only found spurious cexs {cex_list} for the conjunct {conjunct}.")
pass
elif self._gp_model.status == gp.GRB.INFEASIBLE:
continue
elif self._gp_model.status == gp.GRB.INTERRUPTED:
raise KeyboardInterrupt
else:
return z3.unknown
print("Done")
......
%% Cell type:code id: tags:
``` python
import json
import numpy as np
import pickle
import z3
STATE_DIM = 3
PERC_DIM = 2
filename = "dtree_synth.4x10.out.json"
with open(filename, "r") as f:
data = json.load(f)
found_dtree = [dict(part=entry["part"], **entry["result"]) for entry in data if entry["status"] == "found"]
print(len(found_dtree))
found, not_found, spur = 0, 0, 0
repeated_neg = 0
other = 0
for entry in data:
for i, entry in enumerate(data):
if entry["status"] == "found":
found += 1
elif entry["status"] == "not found":
print(f"Not Found for Partition {i}: {entry['part']}")
not_found +=1
elif entry["status"] == "exception":
print(f"Exception for Partition {i}: {entry['part']}")
if "spurious cexs" in entry["result"]:
spur += 1
elif "repeated" in entry["result"]:
repeated_neg += 1
else:
print(entry["result"])
other += 1
print(found, not_found, spur, repeated_neg, other)
```
%% Cell type:code id: tags:
``` python
def z3_float64_const_to_real(v: float) -> z3.RatNumRef:
return z3.simplify(
z3.fpToReal(z3.FPVal(v, z3.Float64()))
)
def in_part(state_arr, part_arr):
assert part_arr.shape == (len(state_arr), 2)
lb_arr, ub_arr = part_arr.T
return np.all(lb_arr <= state_arr) and np.all(state_arr <= ub_arr)
def calc_precision(part, z3_expr) -> float:
def in_z3_expr(sample, z3_expr) -> bool:
assert len(sample) == STATE_DIM + PERC_DIM
state_subs_map = [(z3.Real(f"x_{i}"), z3_float64_const_to_real(sample[i])) for i in range(STATE_DIM)]
perc_subs_map = [(z3.Real(f"z_{i}"), z3_float64_const_to_real(sample[i+STATE_DIM])) for i in range(PERC_DIM)]
sub_map = state_subs_map + perc_subs_map
val = z3.simplify(z3.substitute(z3_expr, *sub_map))
assert z3.is_bool(val)
if z3.is_false(val):
return False
elif z3.is_true(val):
return True
else:
raise RuntimeError(f"Cannot validate negative example {sample} by substitution")
pkl_name = "../data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle"
with open(pkl_name, "rb") as f:
pkl_data = pickle.load(f)
truth_samples_seq = pkl_data["truth_samples"]
part_arr = np.asfarray(part)
num_pos, num_neg, num_nan = 0, 0, 0
for _, ss in truth_samples_seq:
for s in ss:
state_arr = np.asfarray(s[0:3])
if not in_part(state_arr, part_arr):
continue
# else:
if np.any(np.isnan(s)):
num_nan += 1
elif in_z3_expr(s, z3_expr):
num_pos += 1
else:
num_neg += 1
return num_pos, num_neg, num_nan
```
%% Cell type:code id: tags:
``` python
def visitor(e, seen):
if e in seen:
return
seen[e] = True
yield e
if z3.is_app(e):
for ch in e.children():
for e in visitor(ch, seen):
yield e
return
if z3.is_quantifier(e):
for e in visitor(e.body(), seen):
yield e
return
```
%% Cell type:code id: tags:
``` python
for result in found_dtree:
print(result['part'])
decls = {vname: z3.Real(vname) for vname in ["x_0", "x_1", "x_2", "z_0", "z_1"]}
smt2_str = f"(assert {result['formula']})"
z3_assertions = z3.parse_smt2_string(smt2_str, decls=decls)
z3_expr:z3.ExprRef = z3_assertions[0]
# print("#Atomic Predicates:", sum(z3.is_le(e) or z3.is_ge(e) for e in visitor(z3_expr, {})))
# print(z3_expr)
# Calculate the number of paths on a binary tree by adding one more path
# when there is an ite or a disjunction (due to simplification on ite).
# FIXME does not work if an ite expression is a common sub-expression of two paths.
num_paths = 1
for e in visitor(z3_expr, {}):
if z3.is_or(e) or z3.is_app_of(e, z3.Z3_OP_ITE):
num_paths += 1
print("#Paths:", num_paths)
num_pos, num_neg, num_nan = calc_precision(result['part'], z3_expr)
print(f"pos: {num_pos}; neg: {num_neg}; nan: {num_nan}")
print("precision (pos/(pos+neg)):", num_pos / (num_pos + num_neg) )
```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment