#!/usr/bin/env python3

import json
import pickle
from typing import List, Tuple

from joblib import Parallel, delayed
import numpy as np
import z3

from gem_stanley_teacher import DTreeGEMStanleyGurobiStabilityTeacher


def fp_to_real(val: float):
    """Turn a finite Python float into a z3 real-sorted term.

    The float is first written as an IEEE-754 double literal (``Float64``)
    and then converted to the Real sort with ``fpToReal``.
    Non-finite inputs (NaN, +/-inf) are rejected by an assertion.
    """
    assert np.isfinite(val)
    double_sort = z3.Float64()
    fp_literal = z3.FPVal(val, double_sort)
    return z3.fpToReal(fp_literal)


def build_aap_predicate(
        state_vars: z3.ExprRef,
        perc_vars: z3.ExprRef,
        aap_json) -> z3.BoolRef:
    """Assemble the AAP predicate from per-partition synthesis results.

    Only entries of ``aap_json`` whose ``"status"`` is ``"found"`` are used.
    For each of them:

    - ``entry["part"]`` gives one ``(lb, ub)`` pair per state variable. An
      ``-inf`` lower bound or ``+inf`` upper bound means "unbounded" and
      contributes no constraint.
    - The resulting box becomes a guard.
    - The formula deserialized from ``entry["result"]["smtlib"]`` must hold
      whenever the guard holds.

    A final conjunct makes the predicate false for any state that lies
    outside every found partition. ``perc_vars`` is not referenced here.

    Returns:
        ``z3.And`` over one implication per found partition, followed by the
        "outside all partitions implies False" conjunct.
    """
    found = [dict(part=entry["part"], **entry["result"])
             for entry in aap_json if entry["status"] == "found"]

    conjuncts = []  # type: List[z3.BoolRef]
    outside_guards = []  # type: List[z3.BoolRef]
    for record in found:
        box = record['part']
        assert len(state_vars) == len(box)

        bounds = []
        for var, (lower, upper) in zip(state_vars, box):
            if not np.isfinite(lower):
                assert np.isneginf(lower)
            else:
                bounds.append(var >= fp_to_real(lower))
            if not np.isfinite(upper):
                assert np.isposinf(upper)
            else:
                bounds.append(var <= fp_to_real(upper))

        # Build the box constraint once; it is used both as the guard of the
        # implication and, negated, in the "outside" conjunct.
        guard = z3.And(*bounds)
        tree_formula = z3.deserialize(record["smtlib"])
        conjuncts.append(z3.Implies(guard, tree_formula))
        outside_guards.append(z3.Not(guard))

    conjuncts.append(
        z3.Implies(z3.And(*outside_guards), z3.BoolVal(False))
    )
    return z3.And(conjuncts)


def monitoring(
        gt_list: List[Tuple[float, ...]],
        gte_list: List[Tuple[float, ...]],
        gt_var_names: List[str],
        gte_var_names: List[str],
        aap_pred_smtlib: str):
    """Evaluate a serialized predicate at every (ground truth, estimate) step.

    Args:
        gt_list: Ground-truth states, one tuple of floats per step.
        gte_list: Estimated states, aligned index-by-index with ``gt_list``.
        gt_var_names: Names of the z3 Real variables bound, in order, to the
            entries of each ground-truth tuple.
        gte_var_names: Names of the z3 Real variables bound, in order, to the
            entries of each estimate tuple.
        aap_pred_smtlib: The predicate as text accepted by ``z3.deserialize``.

    Returns:
        A list with one bool per step. An entry is True only when
        substituting the concrete values and simplifying gives the literal
        ``True``; any other simplified term counts as False.
    """
    assert len(gt_list) == len(gte_list)
    gt_vars = z3.Reals(gt_var_names)
    gte_vars = z3.Reals(gte_var_names)
    pred = z3.deserialize(aap_pred_smtlib)

    def _holds_at(gt, gte) -> bool:
        # Bind ground-truth values first, then estimate values.
        pairs = [(var, fp_to_real(num)) for var, num in zip(gt_vars, gt)]
        pairs.extend((var, fp_to_real(num)) for var, num in zip(gte_vars, gte))
        reduced = z3.simplify(z3.substitute(pred, *pairs))
        return z3.is_true(reduced)

    return [_holds_at(gt, gte) for gt, gte in zip(gt_list, gte_list)]


def parallel_monitoring(trace_pairs, gt_var_names, gte_var_names, pred: z3.BoolRef, num_jobs: int):
    """Run ``monitoring`` on every trace pair in parallel and print pass rates.

    Args:
        trace_pairs: Iterable of ``(gt_list, gte_list)`` pairs, each in the
            form accepted by ``monitoring``.
        gt_var_names: Names of the ground-truth variables used in ``pred``.
        gte_var_names: Names of the estimate variables used in ``pred``.
        pred: Predicate evaluated at every state of every trace.
        num_jobs: Value passed as ``n_jobs`` to ``joblib.Parallel``.

    Returns:
        One list of booleans per trace pair, in input order, as returned by
        ``monitoring``. If there are no trace pairs, an empty list is returned
        and no statistics are printed.
    """
    # The predicate is handed to the workers as serialized text. It is the
    # same for every trace, so serialize it once instead of once per job.
    pred_smtlib = pred.serialize()
    in_pred_trace_list = Parallel(n_jobs=num_jobs)(
        delayed(monitoring)(
            gt_list, gte_list, gt_var_names, gte_var_names, pred_smtlib)
        for gt_list, gte_list in trace_pairs)
    if not in_pred_trace_list:
        # Nothing to summarize; unpacking an empty stats array would fail.
        return in_pred_trace_list

    num_safe_total_list = [(sum(bool_trace), len(bool_trace)) for bool_trace in in_pred_trace_list]
    in_pred_arr, total_arr = np.array(num_safe_total_list).T
    pass_rate = in_pred_arr / total_arr
    print(f"Min % positive states: {np.min(pass_rate)*100:.2f}%")
    print(f"Avg % positive states: {np.mean(pass_rate)*100:.2f}%")
    print(f"Max % positive states: {np.max(pass_rate)*100:.2f}%")

    return in_pred_trace_list


def check_gt_in_aap(
        aap_pred: z3.BoolRef,
        gt_vars: List[z3.ExprRef],
        gte_vars: List[z3.ExprRef]) -> None:
    """
    Check if the AAP predicate holds when the ground truth equals estimate.

    Every estimate variable in ``aap_pred`` is replaced by the ground-truth
    variable at the same position, and each top-level conjunct is checked
    separately. For a conjunct that is not valid, its index is printed
    together with a counterexample model. If the solver answers ``unknown``,
    no model is available, so the solver's reason is printed instead.

    Args:
        aap_pred: A ``z3.And`` whose children are checked one by one.
        gt_vars: Ground-truth variables.
        gte_vars: Estimate variables, paired positionally with ``gt_vars``.
    """
    assert z3.is_and(aap_pred)

    print("="*80)
    # Make each estimate equal to its ground-truth counterpart. Ground-truth
    # variables stay as they are, so no identity pairs are needed.
    sub_vals = [(gte, gt) for gte, gt in zip(gte_vars, gt_vars)]
    for i, pred in enumerate(aap_pred.children()):
        gt_in_aap = z3.simplify(z3.substitute(pred, *sub_vals))
        if z3.is_true(gt_in_aap):
            continue
        # The part is valid iff its negation is unsatisfiable.
        solver = z3.Solver()
        solver.add(z3.Not(gt_in_aap))
        res = solver.check()
        if res == z3.unsat:
            continue
        print(f"Part {i}")
        if res == z3.sat:
            print(solver.model())
        else:
            # z3.unknown: Solver.model() would raise, so report why instead.
            print(f"Solver returned unknown: {solver.reason_unknown()}")


def main_gem_stanley():
    """Compare the teacher-defined predicate with the AAP predicate on GEM Stanley traces.

    Steps:

    1. Load ground-truth/estimate trace pairs from ``TRACE_PKL_FILE``.
    2. Evaluate the teacher's ``is_safe_state`` on every aligned state.
    3. Evaluate the AAP predicate built from ``AAP_JSON_FILE`` on every
       aligned state.
    4. Print per-trace min/avg/max percentages of positive states, of states
       where "in AAP" implies "in spec", and of states where both verdicts
       agree.
    """
    # Worker count for parallel_monitoring (same value main_agbot_stanley uses).
    NUM_JOBS = 40
    AAP_JSON_FILE = "diff0_0/out/dtree_synth.4x10.out.json"
    TRACE_PKL_FILE = "data/gem_stanley-straight-1000_traces-500_psicte.pickle"

    def print_rate_stats(label, rates):
        # Print the min/avg/max of the per-trace rates as percentages.
        print(f"Min % {label}: {np.min(rates)*100:.2f}%")
        print(f"Avg % {label}: {np.mean(rates)*100:.2f}%")
        print(f"Max % {label}: {np.max(rates)*100:.2f}%")

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(TRACE_PKL_FILE, "rb") as pkl:
        trace_arr_pairs = pickle.load(pkl)
    trace_pairs = []
    for gt_trace_arr, gte_trace_arr in trace_arr_pairs:
        # The ground-truth trace has exactly one more state than the estimate
        # trace; drop its last state so both align index-by-index.
        assert len(gt_trace_arr) == len(gte_trace_arr) + 1
        gt_trace_arr = gt_trace_arr[:-1]
        # Extract only relevant fields and order the fields correctly
        gt_list = gt_trace_arr[['cte', 'psi']].tolist()
        gte_list = gte_trace_arr[['cte', 'psi']].tolist()
        trace_pairs.append((gt_list, gte_list))
    print("# Traces:", len(trace_pairs))
    print("# States in each trace:", len(trace_pairs[0][1]))

    # TODO Merge into parallel monitoring
    print("Monitor with teacher defined predicate")
    teacher = DTreeGEMStanleyGurobiStabilityTeacher(norm_ord=2, ultimate_bound=0.2)
    # The teacher is queried with (0.0, -cte, -psi) followed by the estimate fields.
    in_spec_trace_list = [
        [teacher.is_safe_state((0.0,) + tuple(-v for v in gt) + gte)
         for gt, gte in zip(gt_list, gte_list)]
        for gt_list, gte_list in trace_pairs
    ]
    num_safe_total_list = [(sum(bool_trace), len(bool_trace)) for bool_trace in in_spec_trace_list]
    in_pred_arr, total_arr = np.array(num_safe_total_list).T
    print_rate_stats("positive states", in_pred_arr / total_arr)

    gt_var_names = ["d", "psi"]
    gte_var_names = ["d_e", "psi_e"]
    gt_vars = z3.Reals(gt_var_names)
    gte_vars = z3.Reals(gte_var_names)

    state_vars = [z3.Real(f"x_{i}") for i in range(3)]
    perc_vars = [z3.Real(f"z_{i}") for i in range(2)]
    with open(AAP_JSON_FILE) as json_fp:
        data = json.load(json_fp)
    aap_pred = build_aap_predicate(state_vars, perc_vars, data)

    # NOTE temporarily convert state/latent variables to gt variables:
    # x_1 -> -d, x_2 -> -psi, z_0 -> d_e, z_1 -> psi_e.
    # NOTE(review): x_0 is left free, while the teacher query above passes
    # 0.0 in that slot. If any partition bounds x_0, the monitored predicate
    # may not simplify to True. Confirm whether x_0 should be set to 0.
    subs = [(x, -gt) for x, gt in zip(state_vars[1:], gt_vars)] + \
        [(p, gte) for p, gte in zip(perc_vars, gte_vars)]
    aap_pred = z3.substitute(aap_pred, *subs)
    # check_gt_in_aap(aap_pred, gt_vars, gte_vars)

    aap_pred = z3.simplify(aap_pred)
    print("="*80)
    print("Monitor with AAP predicate")
    in_aap_trace_list = parallel_monitoring(
        trace_pairs, gt_var_names, gte_var_names, aap_pred, NUM_JOBS
    )

    # Calculate agreeing rate
    in_aap_table = np.array(in_aap_trace_list, dtype=bool)
    in_spec_table = np.array(in_spec_trace_list, dtype=bool)

    # Per trace: fraction of states where "in AAP" implies "in spec".
    imply_rates = np.mean(
        np.logical_or(np.logical_not(in_aap_table), in_spec_table), axis=1)
    print("="*80)
    print_rate_stats("imply", imply_rates)

    # Per trace: fraction of states where both predicates give the same verdict.
    agree_rates = np.mean(in_aap_table == in_spec_table, axis=1)
    print("="*80)
    print_rate_stats("agree", agree_rates)

def main_agbot_stanley(max_traces=10):
    """Monitor AgBot Stanley traces against a cross-track-error safety bound.

    Loads ground-truth/estimate trace pairs from ``TRACE_PKL_FILE``. It then
    prints per-trace min/avg/max percentages of states that satisfy
    ``-0.304 <= d <= 0.304``.

    Args:
        max_traces: Monitor only the first ``max_traces`` trace pairs from
            the pickle; ``None`` monitors all of them. The default of 10
            keeps the existing behavior.
    """
    NUM_JOBS = 40
    # Not read yet; kept for the pending AAP comparison (see TODO below).
    AAP_JSON_FILE = "diff0_0/out/dtree_synth.5x5.out.json"
    TRACE_PKL_FILE = "data/agbot_stanley-800_traces-500_psicte.pickle"

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(TRACE_PKL_FILE, "rb") as pkl:
        trace_arr_pairs = pickle.load(pkl)
    trace_pairs = []
    for gt_trace_arr, gte_trace_arr in trace_arr_pairs[:max_traces]:
        # The ground-truth trace has exactly one more state than the estimate
        # trace; drop its last state so both align index-by-index.
        assert len(gt_trace_arr) == len(gte_trace_arr) + 1
        gt_trace_arr = gt_trace_arr[:-1]
        # Extract only relevant fields and order the fields correctly
        gt_list = gt_trace_arr[['cte', 'psi']].tolist()
        gte_list = gte_trace_arr[['cte', 'psi']].tolist()
        trace_pairs.append((gt_list, gte_list))

    print("# Traces:", len(trace_pairs))
    print("# States in each trace:", len(trace_pairs[0][1]))

    gt_var_names = ["d", "psi"]
    gte_var_names = ["d_e", "psi_e"]
    gt_vars = z3.Reals(gt_var_names)
    # 0.304 = 0.5*0.76 where 0.76 is the cornrow width
    # NOTE(review): 0.5*0.76 is 0.38, whereas 0.304 equals 0.4*0.76. Confirm
    # whether the comment or the bound below is the intended value.
    safety_pred = z3.And(-0.304 <= gt_vars[0], gt_vars[0] <= 0.304)
    print("Safety predicate:", safety_pred)
    print("Monitor with safety predicate")
    parallel_monitoring(
        trace_pairs, gt_var_names, gte_var_names, safety_pred, NUM_JOBS
    )
    # TODO Finish implementation once AAP for agbot is available


if __name__ == "__main__":
    main_gem_stanley()
    # Switch to the AgBot experiment by uncommenting the call below.
    # main_agbot_stanley()