Skip to content
Snippets Groups Projects
Commit dc91fffb authored by chsieh16's avatar chsieh16
Browse files

Check safety and use z3 api to (de-)serialize

parent b5a9397e
No related branches found
No related tags found
No related merge requests found
......@@ -127,7 +127,7 @@ def synth_dtree_per_part(
"feature_domain": feature_domain,
"ultimate_bound": ult_bound,
"status": "found",
"result": {"k": k, "formula": z3_expr.sexpr(), "smtlib": smtlib},
"result": {"k": k, "formula": z3_expr.sexpr(), "smtlib": z3_expr.serialize()},
"teacher time": time_info[0],
"learner time": time_info[1]})
else:
......
......@@ -14,7 +14,7 @@ def fp_to_real(val: float):
return z3.fpToReal(z3.FPVal(val, z3.Float64()))
def build_z3_predicate(
def build_aap_predicate(
state_vars: z3.ExprRef,
perc_vars: z3.ExprRef,
aap_json) -> z3.BoolRef:
......@@ -36,13 +36,7 @@ def build_z3_predicate(
else:
assert np.isposinf(ub)
z3_astvec = z3.parse_smt2_string(m["smtlib"])
if len(z3_astvec) == 0:
assert m['formula'] == "true"
dtree = z3.BoolVal(True)
else:
assert len(z3_astvec) == 1
dtree = z3_astvec[0]
dtree = z3.deserialize(m["smtlib"])
aap_formula_list.append(z3.And(*pre_list, dtree))
......@@ -61,9 +55,7 @@ def monitoring(
gt_vars = z3.Reals(gt_var_names)
gte_vars = z3.Reals(gte_var_names)
z3_astvec = z3.parse_smt2_string(aap_pred_smtlib)
assert len(z3_astvec) == 1
aap_pred = z3_astvec[0]
aap_pred = z3.deserialize(aap_pred_smtlib)
# Extract only relevant fields and order the fields correctly
gt_list = gt_trace_arr[['cte', 'psi']].tolist()
......@@ -77,7 +69,7 @@ def monitoring(
z3.simplify(z3.substitute(aap_pred, *subs))
)
bool_list = [z3.is_true(b) for b in bool_list]
return sum(bool_list), len(bool_list)
return bool_list
def main():
......@@ -85,39 +77,51 @@ def main():
AAP_JSON_FILE = "diff0_0/out/dtree_synth.4x10.out.json"
TRACE_PKL_FILE = "data/gem_stanley-straight-1000_traces-500_psicte.pickle"
state_vars = [z3.Real(f"x_{i}") for i in range(3)]
perc_vars = [z3.Real(f"z_{i}") for i in range(2)]
with open(TRACE_PKL_FILE, "rb") as pkl:
trace_pairs = pickle.load(pkl)
with open(AAP_JSON_FILE) as json_fp:
data = json.load(json_fp)
aap_pred = build_z3_predicate(state_vars, perc_vars, data)
print("# Traces:", len(trace_pairs))
print("# States in each trace:", len(trace_pairs[0][1]))
# NOTE temporarily convert state/latent variables to gt variables
gt_var_names = ["d", "psi"]
gte_var_names = ["d_e", "psi_e"]
gt_vars = z3.Reals(gt_var_names)
gte_vars = z3.Reals(gte_var_names)
safety_pred = z3.Abs(gt_vars[0]) <= 1.6
print("Safety predicate:", safety_pred)
in_safe_trace_list = Parallel(NUM_JOBS)(
delayed(monitoring)(
gt_arr, gte_arr, gt_var_names, gte_var_names,
safety_pred.serialize())
for gt_arr, gte_arr in trace_pairs)
num_safe_total_list = [(sum(bool_trace), len(bool_trace)) for bool_trace in in_safe_trace_list]
safe_arr, total_arr = np.array(num_safe_total_list).T
safe_rate = safe_arr / total_arr
print(f"% states in safe: {np.mean(safe_rate)*100:.2f}%")
state_vars = [z3.Real(f"x_{i}") for i in range(3)]
perc_vars = [z3.Real(f"z_{i}") for i in range(2)]
with open(AAP_JSON_FILE) as json_fp:
data = json.load(json_fp)
aap_pred = build_aap_predicate(state_vars, perc_vars, data)
# NOTE temporarily convert state/latent variables to gt variables
subs = [(x, -gt) for x, gt in zip(state_vars[1:], gt_vars)] + \
[(p, gte) for p, gte in zip(perc_vars, gte_vars)]
aap_pred = z3.simplify(z3.substitute(aap_pred, *subs))
# Convert to SMTLib string for pickling and multiprocessing
solver = z3.Solver()
solver.add(aap_pred)
aap_pred_smtlib = solver.to_smt2()
with open(TRACE_PKL_FILE, "rb") as pkl:
trace_pairs = pickle.load(pkl)
num_pass_total_list = Parallel(NUM_JOBS)(
in_aap_trace_list = Parallel(NUM_JOBS)(
delayed(monitoring)(
gt_arr, gte_arr, gt_var_names, gte_var_names,
aap_pred_smtlib)
aap_pred.serialize())
for gt_arr, gte_arr in trace_pairs)
num_pass_total_list = [(sum(bool_trace), len(bool_trace)) for bool_trace in in_aap_trace_list]
pass_arr, total_arr = np.array(num_pass_total_list).T
pass_rate = pass_arr / total_arr
print(np.mean(pass_rate))
print(f"% states in AAP: {np.mean(pass_rate)*100:.2f}%")
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment