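# dtree_synth.py
# NOTE: the repository web viewer's page chrome ("Skip to content",
# "Snippets Groups Projects", "5.57 KiB", "Newer Older") was captured here;
# it is not part of the program.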
import dataclasses
import pathlib
import traceback
from typing import Callable, Dict, Hashable, List, Literal, Tuple

import numpy as np
import z3
from learner_base import LearnerBase
from teacher_base import TeacherBase
# NOTE: repository-viewer blame marker ("aastorg2's avatar" / "aastorg2 committed"), not code.
import time
@dataclasses.dataclass
class DataSet:
    """Datapoints gathered for one partition, grouped by label.

    Every instance gets its own fresh lists, so appending to one
    ``DataSet`` never affects another.
    """

    # Datapoints labelled safe; synth_dtree_per_part passes them to the
    # learner as positive examples.
    safe_dps: List[Tuple[float, ...]] = dataclasses.field(
        default_factory=list,
    )
    # Datapoints labelled unsafe; synth_dtree_per_part passes them to the
    # learner as negative examples.
    unsafe_dps: List[Tuple[float, ...]] = dataclasses.field(
        default_factory=list,
    )
    # Count of NaN datapoints; within this file it is only printed.
    num_nan_dps: int = 0
def synth_dtree(learner: LearnerBase, teacher: TeacherBase, num_max_iterations: int = 10):
    """Run the learner/teacher refinement loop until a candidate is accepted.

    Each iteration asks ``learner`` for a candidate and hands it to
    ``teacher.check``.  When the teacher answers ``z3.sat`` its model is fed
    back to the learner as negative examples and the loop goes on; ``z3.unsat``
    accepts the candidate; any other verdict aborts the search.

    Args:
        learner: Provides ``learn()`` and ``add_negative_examples(*examples)``.
        teacher: Provides ``check(candidate)``, ``model()`` and
            ``reason_unknown()``.
        num_max_iterations: Upper bound on the number of learn/check rounds.

    Returns:
        A 3-tuple ``(found, payload, (teacher_time, learner_time))``:

        * ``found`` is True when the teacher accepted a candidate; ``payload``
          is then ``(k, formula)`` with ``k`` the zero-based iteration index
          and ``formula`` the candidate after ``z3.simplify``.
        * ``found`` is False otherwise; ``payload`` is then a message string
          (teacher verdict unknown, or iteration budget exhausted).
        * The two times are the accumulated seconds spent inside
          ``teacher.check`` and ``learner.learn`` respectively.

    Raises:
        AssertionError: If the teacher answers ``z3.sat`` but its model
            contains no negative example.
    """
    teacher_time = 0.0
    learner_time = 0.0
    for k in range(num_max_iterations):
        print("="*80)
        print(f"Iteration {k}:", sep='')
        print("learning ....")
        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by wall-clock adjustments.
        learn_start = time.perf_counter()
        candidate = learner.learn()
        learner_time += time.perf_counter() - learn_start

        print("done learning")
        print(f"candidate: {candidate}")

        # Ask the teacher whether the candidate still admits negative examples.
        check_start = time.perf_counter()
        result = teacher.check(candidate)
        teacher_time += time.perf_counter() - check_start

        print(f"Satisfiability: {result}")
        if result == z3.sat:
            negative_examples = teacher.model()
            assert len(negative_examples) > 0
            print(f"negative examples: {negative_examples}")

            # assert validate_cexs(teacher, candidate, negative_examples)
            learner.add_negative_examples(*negative_examples)
        elif result == z3.unsat:
            print("we are done!")
            return True, (k, z3.simplify(candidate, arith_lhs=True)), (teacher_time, learner_time)
        else:
            # Neither sat nor unsat: report why instead of silently retrying
            # with a learner that received no new examples.
            return False, f"Reason Unknown {teacher.reason_unknown()}", (teacher_time, learner_time)
    return False, f"Reached max iteration {num_max_iterations}.", (teacher_time, learner_time)
def validate_cexs(teacher: TeacherBase,
                  candidate: z3.BoolRef,
                  cexs: List[Tuple[float]]) -> bool:
    """Tell whether every counterexample is genuine for ``candidate``.

    Each counterexample is passed to ``teacher.is_spurious_example`` together
    with ``teacher.state_dim``, ``teacher.perc_dim`` and ``candidate``.

    Args:
        teacher: Object exposing ``is_spurious_example``, ``state_dim`` and
            ``perc_dim``.
        candidate: Formula the counterexamples were produced for.
        cexs: Counterexamples to validate.

    Returns:
        True when no counterexample is flagged as spurious; otherwise the
        spurious ones are printed, one per line, and False is returned.
    """
    flagged = []
    for example in cexs:
        if teacher.is_spurious_example(teacher.state_dim, teacher.perc_dim, candidate, example):
            flagged.append(example)

    if flagged:
        print("Spurious CEXs:", *flagged, sep='\n')
        return False
    return True
def search_part(partition, state):
    """Locate the cell of ``partition`` that contains ``state``.

    Args:
        partition: One sorted sequence of boundary values per dimension.
        state: One value per dimension; must have as many entries as
            ``partition``.

    Returns:
        A tuple of ``(lower, upper)`` boundary pairs, one per dimension, with
        ``lower < value <= upper``.  ``None`` as soon as some value is not
        strictly above the first boundary or is above the last boundary of
        its dimension.
    """
    assert len(partition) == len(state)
    cell = []
    for edges, value in zip(partition, state):
        # Left-sided insertion point: edges[pos - 1] < value <= edges[pos].
        pos = np.searchsorted(edges, value)
        if not 0 < pos < len(edges):
            return None
        lower, upper = edges[pos - 1], edges[pos]
        cell.append((lower, upper))
    return tuple(cell)


def synth_dtree_per_part(
        part_to_examples: Dict[Hashable, DataSet],
        teacher_builder: Callable[[], TeacherBase],
        learner_builder: Callable[[], LearnerBase],
        num_max_iter: int,
        ult_bound: float = 0.0,
        feature_domain: Literal["concat", "diff"] = "diff") -> List[Dict]:
    """Synthesize one decision tree per partition and collect the outcomes.

    For every partition that has at least one safe datapoint, a fresh teacher
    and learner are built, seeded with the partition's bounds and examples,
    and handed to ``synth_dtree``.

    Args:
        part_to_examples: Maps each partition to its datapoints.
        teacher_builder: Zero-argument factory for a new teacher.
        learner_builder: Zero-argument factory for a new learner.
        num_max_iter: Iteration budget forwarded to ``synth_dtree``.
        ult_bound: Copied verbatim into each result as ``"ultimate_bound"``.
        feature_domain: Copied verbatim into each result as
            ``"feature_domain"``.

    Returns:
        One dict per processed partition.  ``"status"`` is one of ``"skip"``,
        ``"found"``, ``"not found"`` or ``"exception"``; ``"result"`` holds
        respectively a message, a dict with ``"k"``/``"formula"``/``"smtlib"``,
        the message from ``synth_dtree``, or a formatted traceback.  Entries
        with status ``"found"`` additionally carry ``"teacher time"`` and
        ``"learner time"``.  Partitions after a Ctrl+C are not listed.
    """
    result = []
    for i, (part, dataset) in enumerate(part_to_examples.items()):
        safe_dps, unsafe_dps, num_nan = dataset.safe_dps, dataset.unsafe_dps, dataset.num_nan_dps
        # A partition without safe datapoints is skipped.
        if len(safe_dps) == 0:
            result.append({"part": part,
                           "status": "skip",
                           "result": "No safe datapoints."})
            continue

        print("#"*80)
        print(f"# safe: {len(safe_dps)}; "
              f"# unsafe: {len(unsafe_dps)}; "
              f"# NaN: {num_nan}")
        # np.asfarray was removed in NumPy 2.0; asarray(..., dtype=float) is
        # the equivalent spelling.
        # Assumes `part` is a sequence of (lower, upper) pairs, like the
        # tuples built by search_part - TODO confirm against the caller.
        lb, ub = np.asarray(part, dtype=float).T

        # NOTE We create new teacher and learner for each part to avoid solutions from other parts
        # TODO incremental solving
        teacher = teacher_builder()
        teacher.set_old_state_bound(lb=lb, ub=ub)

        learner = learner_builder()
        learner.add_positive_examples(*safe_dps)
        learner.add_negative_examples(*unsafe_dps)
        try:
            found, ret, time_info = synth_dtree(learner, teacher,
                                                num_max_iterations=num_max_iter)
            print(f"Found? {found}")
            if found:
                k, z3_expr = ret
                result.append({"part": part,
                               "feature_domain": feature_domain,
                               "ultimate_bound": ult_bound,
                               "status": "found",
                               "result": {"k": k, "formula": z3_expr.sexpr(), "smtlib": z3_expr.serialize()},
                               "teacher time": time_info[0],
                               "learner time": time_info[1]})
            else:
                result.append({"part": part,
                               "feature_domain": feature_domain,
                               "ultimate_bound": ult_bound,
                               "status": "not found",
                               "result": ret})
        except KeyboardInterrupt:
            print("User pressed Ctrl+C. Skip all remaining partition.")
            break  # NOTE finally block is executed before break
        except Exception as e:
            # Record the failure and move on to the next partition.
            result.append({"part": part,
                           "status": "exception",
                           "result": traceback.format_exc()})
            print(e)
        finally:
            # Keep this partition's data file under a per-partition name.
            # NOTE(review): out/pre.data is presumably written while the
            # teacher runs - confirm.  If the run failed before the file
            # existed, an unguarded rename would raise from this finally
            # block and discard every result collected so far.
            data_file = pathlib.Path("out/pre.data")
            if data_file.exists():
                data_file.rename(f"out/part-{i:03}-pre.data")
            else:
                print(f"{data_file} not found; nothing to rename for part {i}.")
            del teacher
            del learner
    return result