Skip to content
Snippets Groups Projects
Commit 6497d7ad authored by chsieh16's avatar chsieh16
Browse files

Print out parts with exceptions

parent 8458576f
No related branches found
No related tags found
No related merge requests found
%% Cell type:code id: tags:
``` python
import json
import numpy as np
import pickle
import z3
STATE_DIM = 3
PERC_DIM = 2
filename = "dtree_synth.4x10.out.json"
with open(filename, "r") as f:
data = json.load(f)
found_dtree = [dict(part=entry["part"], **entry["result"]) for entry in data if entry["status"] == "found"]
print(len(found_dtree))
found, not_found, spur = 0, 0, 0
repeated_neg = 0
other = 0
for entry in data:
for i, entry in enumerate(data):
if entry["status"] == "found":
found += 1
elif entry["status"] == "not found":
print(f"Not Found for Partition {i}: {entry['part']}")
not_found +=1
elif entry["status"] == "exception":
print(f"Exception for Partition {i}: {entry['part']}")
if "spurious cexs" in entry["result"]:
spur += 1
elif "repeated" in entry["result"]:
repeated_neg += 1
else:
print(entry["result"])
other += 1
print(found, not_found, spur, repeated_neg, other)
```
%% Cell type:code id: tags:
``` python
def z3_float64_const_to_real(v: float) -> z3.RatNumRef:
return z3.simplify(
z3.fpToReal(z3.FPVal(v, z3.Float64()))
)
def in_part(state_arr, part_arr):
assert part_arr.shape == (len(state_arr), 2)
lb_arr, ub_arr = part_arr.T
return np.all(lb_arr <= state_arr) and np.all(state_arr <= ub_arr)
def calc_precision(part, z3_expr) -> float:
def in_z3_expr(sample, z3_expr) -> bool:
assert len(sample) == STATE_DIM + PERC_DIM
state_subs_map = [(z3.Real(f"x_{i}"), z3_float64_const_to_real(sample[i])) for i in range(STATE_DIM)]
perc_subs_map = [(z3.Real(f"z_{i}"), z3_float64_const_to_real(sample[i+STATE_DIM])) for i in range(PERC_DIM)]
sub_map = state_subs_map + perc_subs_map
val = z3.simplify(z3.substitute(z3_expr, *sub_map))
assert z3.is_bool(val)
if z3.is_false(val):
return False
elif z3.is_true(val):
return True
else:
raise RuntimeError(f"Cannot validate negative example {sample} by substitution")
pkl_name = "../data/800_truths-uniform_partition_4x20-1.2m-pi_12-one_straight_road-2021-10-27-08-49-17.bag.pickle"
with open(pkl_name, "rb") as f:
pkl_data = pickle.load(f)
truth_samples_seq = pkl_data["truth_samples"]
part_arr = np.asfarray(part)
num_pos, num_neg, num_nan = 0, 0, 0
for _, ss in truth_samples_seq:
for s in ss:
state_arr = np.asfarray(s[0:3])
if not in_part(state_arr, part_arr):
continue
# else:
if np.any(np.isnan(s)):
num_nan += 1
elif in_z3_expr(s, z3_expr):
num_pos += 1
else:
num_neg += 1
return num_pos, num_neg, num_nan
```
%% Cell type:code id: tags:
``` python
def visitor(e, seen):
if e in seen:
return
seen[e] = True
yield e
if z3.is_app(e):
for ch in e.children():
for e in visitor(ch, seen):
yield e
return
if z3.is_quantifier(e):
for e in visitor(e.body(), seen):
yield e
return
```
%% Cell type:code id: tags:
``` python
for result in found_dtree:
print(result['part'])
decls = {vname: z3.Real(vname) for vname in ["x_0", "x_1", "x_2", "z_0", "z_1"]}
smt2_str = f"(assert {result['formula']})"
z3_assertions = z3.parse_smt2_string(smt2_str, decls=decls)
z3_expr:z3.ExprRef = z3_assertions[0]
# print("#Atomic Predicates:", sum(z3.is_le(e) or z3.is_ge(e) for e in visitor(z3_expr, {})))
# print(z3_expr)
# Calculate the number of paths on a binary tree by adding one more path
# when there is an ite or a disjunction (due to simplification on ite).
# FIXME does not work if an ite expression is a common sub-expression of two paths.
num_paths = 1
for e in visitor(z3_expr, {}):
if z3.is_or(e) or z3.is_app_of(e, z3.Z3_OP_ITE):
num_paths += 1
print("#Paths:", num_paths)
num_pos, num_neg, num_nan = calc_precision(result['part'], z3_expr)
print(f"pos: {num_pos}; neg: {num_neg}; nan: {num_nan}")
print("precision (pos/(pos+neg)):", num_pos / (num_pos + num_neg) )
```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment