Newer
Older
#!/usr/bin/env python3
import json
import pickle
from typing import List, Tuple
from joblib import Parallel, delayed
import numpy as np
import z3
from gem_stanley_teacher import DTreeGEMStanleyGurobiStabilityTeacher
def fp_to_real(val: float):
    """Convert a finite float into a z3 real term.

    The value is first encoded as an IEEE-754 double (``z3.Float64``) and then
    converted with ``z3.fpToReal``, so the real term denotes the value of the
    double itself rather than a decimal re-parse of it.

    Raises:
        ValueError: if ``val`` is NaN or infinite (no real counterpart).
    """
    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if not np.isfinite(val):
        raise ValueError(f"Cannot convert non-finite value {val!r} to a z3 real")
    return z3.fpToReal(z3.FPVal(val, z3.Float64()))
def build_aap_predicate(state_vars: List[z3.ExprRef],
                        perc_vars: List[z3.ExprRef],
                        aap_json: List[dict]) -> z3.BoolRef:
    """Build a z3 predicate from the synthesized decision trees in ``aap_json``.

    Only entries whose ``"status"`` is ``"found"`` are used. Each of these
    entries provides:

    * ``"part"``: a box over the state space, given as one ``(lb, ub)`` pair
      per state variable. ``-inf``/``+inf`` mean the variable is unbounded
      on that side.
    * ``"result"["smtlib"]``: a serialized z3 formula, read with
      ``z3.deserialize``.

    The returned predicate is a conjunction of:

    * ``part_i -> dtree_i`` for every found part, and
    * ``(not part_1 and ... and not part_k) -> False``. In other words, a
      state outside every found part never satisfies the predicate.

    Args:
        state_vars: z3 variables, one per dimension of each ``"part"``.
        perc_vars: not used by this function.
        aap_json: list of entries as loaded from the synthesizer's JSON file.

    Returns:
        A ``z3.And`` whose children are the per-part implications, followed
        by the catch-all clause.
    """
    found_dtree = [dict(part=entry["part"], **entry["result"])
                   for entry in aap_json if entry["status"] == "found"]
    aap_formula_list = []
    otherwise_list = []  # negated guards of all found parts
    for m in found_dtree:
        assert len(state_vars) == len(m['part'])
        pre_list = []
        for x, (lb, ub) in zip(state_vars, m['part']):
            if np.isfinite(lb):
                pre_list.append(x >= fp_to_real(lb))
            else:
                assert np.isneginf(lb)
            if np.isfinite(ub):
                pre_list.append(x <= fp_to_real(ub))
            else:
                assert np.isposinf(ub)
        # A part unbounded in every dimension has an empty guard.
        # Not every z3py version accepts ``z3.And()`` with no arguments,
        # so use True explicitly.
        guard = z3.And(*pre_list) if pre_list else z3.BoolVal(True)
        dtree = z3.deserialize(m["smtlib"])
        aap_formula_list.append(z3.Implies(guard, dtree))
        otherwise_list.append(z3.Not(guard))
    # Outside every found part the predicate is false.
    # With no found parts, this clause alone makes the predicate False.
    outside_all = (z3.And(*otherwise_list) if otherwise_list
                   else z3.BoolVal(True))
    aap_formula_list.append(z3.Implies(outside_all, z3.BoolVal(False)))
    return z3.And(aap_formula_list)
def monitoring(gt_list: List[Tuple[float, ...]],
               gte_list: List[Tuple[float, ...]],
               gt_var_names: List[str],
               gte_var_names: List[str],
               aap_pred_smtlib: str) -> List[bool]:
    """Evaluate a serialized z3 predicate on every state of one trace.

    Args:
        gt_list: ground-truth values per time step, ordered like
            ``gt_var_names``.
        gte_list: estimated values per time step, ordered like
            ``gte_var_names``. Must have the same length as ``gt_list``.
        gt_var_names: names of the z3 real variables bound to ``gt_list``.
        gte_var_names: names of the z3 real variables bound to ``gte_list``.
        aap_pred_smtlib: predicate serialized with ``BoolRef.serialize()``.
            ``parallel_monitoring`` passes it as a string to joblib tasks.

    Returns:
        One bool per time step. An entry is True iff the predicate
        simplifies to ``True`` after substituting that step's values. A
        predicate that does not reduce to a constant (e.g. it still mentions
        other free variables) counts as False.
    """
    assert len(gt_list) == len(gte_list)
    gt_vars = z3.Reals(gt_var_names)
    gte_vars = z3.Reals(gte_var_names)
    aap_pred = z3.deserialize(aap_pred_smtlib)
    in_pred_list = []  # type: List[bool]
    for gt, gte in zip(gt_list, gte_list):
        subs = [(var, fp_to_real(val)) for var, val in zip(gt_vars, gt)] \
            + [(var, fp_to_real(val)) for var, val in zip(gte_vars, gte)]
        in_pred_list.append(
            z3.is_true(z3.simplify(z3.substitute(aap_pred, *subs)))
        )
    return in_pred_list
def parallel_monitoring(trace_pairs, gt_var_names, gte_var_names, pred: z3.BoolRef, num_jobs: int):
    """Monitor ``pred`` on every trace pair in parallel and print pass-rate stats.

    Each element of ``trace_pairs`` is a ``(gt_list, gte_list)`` pair. Each
    pair is evaluated by ``monitoring`` as its own joblib task, using
    ``num_jobs`` workers. Afterwards the function prints the min/avg/max
    (over traces) percentage of states that satisfy ``pred``.

    Returns:
        One list of per-state booleans for each trace pair, in input order.
    """
    # The serialized predicate is the same for every task, so compute it once.
    pred_smtlib = pred.serialize()
    in_pred_trace_list = Parallel(n_jobs=num_jobs)(
        delayed(monitoring)(
            gt_list, gte_list, gt_var_names, gte_var_names, pred_smtlib)
        for gt_list, gte_list in trace_pairs)
    if not in_pred_trace_list:
        # Nothing to summarize; np.array([]).T could not be unpacked below.
        print("No traces to monitor")
        return in_pred_trace_list
    num_safe_total_list = [(sum(bool_trace), len(bool_trace))
                           for bool_trace in in_pred_trace_list]
    in_pred_arr, total_arr = np.array(num_safe_total_list).T
    pass_rate = in_pred_arr / total_arr
    print(f"Min % positive states: {np.min(pass_rate)*100:.2f}%")
    print(f"Avg % positive states: {np.mean(pass_rate)*100:.2f}%")
    print(f"Max % positive states: {np.max(pass_rate)*100:.2f}%")
    return in_pred_trace_list
def check_gt_in_aap(
        aap_pred: z3.BoolRef,
        gt_vars: List[z3.ExprRef],
        gte_vars: List[z3.ExprRef]) -> None:
    """
    Check if the AAP predicate holds when the ground truth equals estimate.

    First, each estimate variable in ``gte_vars`` is replaced by the
    ground-truth variable at the same position. Then each conjunct of
    ``aap_pred`` is checked for validity with a z3 solver. For every conjunct
    not proven valid, the function prints its index, followed by either a
    counterexample model or, when the solver returns ``unknown``, its reason.

    Args:
        aap_pred: a ``z3.And`` conjunction (asserted).
        gt_vars: ground-truth variables.
        gte_vars: estimate variables, aligned position-wise with ``gt_vars``.
    """
    assert z3.is_and(aap_pred)
    print("="*80)
    # Identify each estimate variable with its ground-truth counterpart.
    sub_vals = list(zip(gte_vars, gt_vars))
    for i, pred in enumerate(aap_pred.children()):
        gt_in_aap = z3.simplify(z3.substitute(pred, *sub_vals))
        if z3.is_true(gt_in_aap):
            continue
        s = z3.Solver()
        s.add(z3.Not(gt_in_aap))
        res = s.check()
        if res == z3.unsat:
            continue  # the conjunct is valid
        print(f"Part {i}")
        if res == z3.sat:
            print(s.model())
        else:
            # s.model() raises when the result is unknown; report why instead.
            print("unknown:", s.reason_unknown())
# GEM Stanley experiment.
# 1. Load recorded (ground-truth, estimate) trace pairs.
# 2. Evaluate them with the teacher's safety check.
# 3. Evaluate them with the predicate built from the synthesized decision
#    trees in AAP_JSON_FILE.
# 4. Report per-trace imply/agree rates between the two verdicts.
# NOTE(review): these lines look like the body of an entry point whose `def`
# line is not visible here. `NUM_JOBS` and `gt_var_names` are used below but
# never assigned in this section; confirm they are defined, otherwise this
# raises NameError.
AAP_JSON_FILE = "diff0_0/out/dtree_synth.4x10.out.json"
TRACE_PKL_FILE = "data/gem_stanley-straight-1000_traces-500_psicte.pickle"
# NOTE(review): pickle.load can execute arbitrary code; only load trusted,
# locally generated trace files.
with open(TRACE_PKL_FILE, "rb") as pkl:
trace_arr_pairs = pickle.load(pkl)
trace_pairs = []
for gt_trace_arr, gte_trace_arr in trace_arr_pairs:
# Extract only the ('cte', 'psi') fields, in that order. Multi-field
# indexing plus .tolist() presumably makes these numpy structured arrays
# that yield one tuple per state; confirm against the pickle producer.
assert len(gt_trace_arr) == len(gte_trace_arr) + 1
# Drop the last ground-truth state so both traces have the same length.
gt_trace_arr = gt_trace_arr[:-1]
gt_list = gt_trace_arr[['cte', 'psi']].tolist()
gte_list = gte_trace_arr[['cte', 'psi']].tolist()
trace_pairs.append((gt_list, gte_list))
print("# Traces:", len(trace_pairs))
# NOTE(review): this reports the first trace's length only, and raises
# IndexError when no traces were loaded.
print("# States in each trace:", len(trace_pairs[0][1]))
# TODO Merge into parallel monitoring
print("Monitor with teacher defined predicate")
teacher = DTreeGEMStanleyGurobiStabilityTeacher(norm_ord=2, ultimate_bound=0.2)
# Ask the teacher whether each aligned state is safe. The argument is
# (0.0, -cte, -psi, cte_e, psi_e). The meaning of the leading 0.0 and of the
# sign flip depends on the teacher's state layout — TODO confirm.
in_spec_trace_list = [
[teacher.is_safe_state((0.0,) + tuple(-v for v in gt) + gte)
for gt, gte in zip(gt_list, gte_list)]
for gt_list, gte_list in trace_pairs
]
# Per-trace fraction of states the teacher considers safe.
num_safe_total_list = [(sum(bool_trace), len(bool_trace)) for bool_trace in in_spec_trace_list]
in_pred_arr, total_arr = np.array(num_safe_total_list).T
pass_rate = in_pred_arr / total_arr
print(f"Min % positive states: {np.min(pass_rate)*100:.2f}%")
print(f"Avg % positive states: {np.mean(pass_rate)*100:.2f}%")
print(f"Max % positive states: {np.max(pass_rate)*100:.2f}%")
gte_var_names = ["d_e", "psi_e"]
gt_vars = z3.Reals(gt_var_names)
gte_vars = z3.Reals(gte_var_names)
state_vars = [z3.Real(f"x_{i}") for i in range(3)]
perc_vars = [z3.Real(f"z_{i}") for i in range(2)]
with open(AAP_JSON_FILE) as json_fp:
data = json.load(json_fp)
aap_pred = build_aap_predicate(state_vars, perc_vars, data)
# NOTE temporarily convert state/latent variables to gt variables:
# x_1, x_2 := -gt_vars[0], -gt_vars[1] and z_0, z_1 := gte_vars[0], gte_vars[1].
# NOTE(review): x_0 is not substituted, so aap_pred may still mention it.
# If a part's guard or tree depends on x_0, monitoring's
# z3.is_true(simplify(...)) yields False for that state. The teacher call
# above fixes x_0 to 0.0; confirm whether x_0 should be substituted here too.
subs = [(x, -gt) for x, gt in zip(state_vars[1:], gt_vars)] + \
[(p, gte) for p, gte in zip(perc_vars, gte_vars)]
aap_pred = z3.substitute(aap_pred, *subs)
# check_gt_in_aap(aap_pred, gt_vars, gte_vars)
aap_pred = z3.simplify(aap_pred)
print("="*80)
in_aap_trace_list = parallel_monitoring(
trace_pairs, gt_var_names, gte_var_names, aap_pred, NUM_JOBS
)
# Compare the two verdicts state by state. np.array(..., dtype=bool) needs
# every trace to have the same length (a rectangular table).
in_aap_table = np.array(in_aap_trace_list, dtype=bool)
in_spec_table = np.array(in_spec_trace_list, dtype=bool)
# imply: per-trace fraction of states where "in AAP" implies "in spec",
# i.e. (not in_aap) or in_spec.
imply_rates = [sum(is_imply_trace)/len(is_imply_trace)
for is_imply_trace in np.logical_or(np.logical_not(in_aap_table), in_spec_table)]
print("="*80)
print(f"Min % imply: {np.min(imply_rates)*100:.2f}%")
print(f"Avg % imply: {np.mean(imply_rates)*100:.2f}%")
print(f"Max % imply: {np.max(imply_rates)*100:.2f}%")
# agree: per-trace fraction of states where both verdicts are equal.
agree_rates = [sum(is_agree_trace)/len(is_agree_trace)
for is_agree_trace in np.logical_not(np.logical_xor(in_aap_table, in_spec_table))]
print("="*80)
print(f"Min % agree: {np.min(agree_rates)*100:.2f}%")
print(f"Avg % agree: {np.mean(agree_rates)*100:.2f}%")
print(f"Max % agree: {np.max(agree_rates)*100:.2f}%")
def main_agbot_stanley():
    """Monitor recorded AgBot Stanley traces against a cross-track safety bound.

    Loads (ground-truth, estimate) trace pairs from the pickle file and keeps
    the first ``MAX_TRACES`` of them. It then uses ``parallel_monitoring`` to
    print how often the ground-truth ``d`` (the 'cte' field) satisfies
    ``-0.304 <= d <= 0.304``.
    """
    NUM_JOBS = 40
    # Not used yet; see the TODO at the end of this function.
    AAP_JSON_FILE = "diff0_0/out/dtree_synth.5x5.out.json"
    TRACE_PKL_FILE = "data/agbot_stanley-800_traces-500_psicte.pickle"
    MAX_TRACES = 10  # only a prefix of the recorded traces is monitored

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted,
    # locally generated trace files.
    with open(TRACE_PKL_FILE, "rb") as pkl:
        trace_arr_pairs = pickle.load(pkl)

    trace_pairs = []
    for gt_trace_arr, gte_trace_arr in trace_arr_pairs[:MAX_TRACES]:
        assert len(gt_trace_arr) == len(gte_trace_arr) + 1
        # Drop the last ground-truth state so both traces have the same length.
        gt_trace_arr = gt_trace_arr[:-1]
        # Extract only relevant fields and order the fields correctly
        gt_list = gt_trace_arr[['cte', 'psi']].tolist()
        gte_list = gte_trace_arr[['cte', 'psi']].tolist()
        trace_pairs.append((gt_list, gte_list))
    print("# Traces:", len(trace_pairs))
    if not trace_pairs:
        print("No traces loaded; nothing to monitor.")
        return
    print("# States in each trace:", len(trace_pairs[0][1]))

    gt_var_names = ["d", "psi"]
    gte_var_names = ["d_e", "psi_e"]
    gt_vars = z3.Reals(gt_var_names)
    gte_vars = z3.Reals(gte_var_names)  # noqa: F841 -- kept for the pending AAP check

    # NOTE(review): 0.304 = 0.4*0.76, where 0.76 is the cornrow width. Half the
    # cornrow width would be 0.38; confirm which bound is intended.
    safety_pred = z3.And(-0.304 <= gt_vars[0], gt_vars[0] <= 0.304)
    print("Safety predicate:", safety_pred)
    print("Monitor with safety predicate")
    parallel_monitoring(
        trace_pairs, gt_var_names, gte_var_names, safety_pred, NUM_JOBS
    )
    # TODO Finish implementation once AAP for agbot is available