import dataclasses
import pathlib
import traceback
from typing import Callable, Dict, Hashable, List, Literal, Tuple

import numpy as np
import z3

from learner_base import LearnerBase
from teacher_base import TeacherBase


@dataclasses.dataclass
class DataSet:
    """Data points collected for a single partition.

    Each instance starts with its own empty lists (``default_factory``), so
    data sets never share storage.
    """
    # Data points labelled safe; elsewhere in this file they are fed to the
    # learner as positive examples.
    safe_dps: List[Tuple[float, ...]] = dataclasses.field(default_factory=list)
    # Data points labelled unsafe; elsewhere in this file they are fed to the
    # learner as negative examples.
    unsafe_dps: List[Tuple[float, ...]] = dataclasses.field(default_factory=list)
    # Count of data points recorded as NaN (only reported, never learned from
    # in the code visible here).
    num_nan_dps: int = 0


def synth_dtree(learner: LearnerBase, teacher: TeacherBase, num_max_iterations: int = 10):
    """Run the learner/teacher refinement loop until the teacher accepts.

    In every iteration the learner proposes a candidate and the teacher
    checks it.  A ``z3.sat`` answer yields negative examples that are handed
    back to the learner; ``z3.unsat`` means the candidate is accepted.

    Args:
        learner: Object providing ``learn()`` and ``add_negative_examples()``.
        teacher: Object providing ``check()``, ``model()`` and
            ``reason_unknown()``.
        num_max_iterations: Upper bound on the number of refinement rounds.

    Returns:
        A pair ``(found, payload)``:

        * ``(True, (k, formula))`` when the teacher answers ``z3.unsat`` in
          iteration ``k``; ``formula`` is the s-expression string of the
          simplified candidate.
        * ``(False, message)`` when the teacher answers anything other than
          sat/unsat, or when ``num_max_iterations`` rounds were used up.

    Raises:
        AssertionError: If the teacher answers ``z3.sat`` but supplies no
            negative examples (the loop could not make progress).
    """
    for k in range(num_max_iterations):
        print("=" * 80)
        print(f"Iteration {k}:")
        print("learning ....")
        candidate = learner.learn()
        print("done learning")
        print(f"candidate: {candidate}")

        # Ask the teacher whether the candidate admits counterexamples.
        result = teacher.check(candidate)
        print(f"Satisfiability: {result}")
        if result == z3.sat:
            negative_examples = teacher.model()
            assert len(negative_examples) > 0
            print(f"negative examples: {negative_examples}")
            # Optional sanity check, disabled by default:
            # assert validate_cexs(teacher, candidate, negative_examples)
            learner.add_negative_examples(*negative_examples)
        elif result == z3.unsat:
            print("we are done!")
            return True, (k, z3.simplify(candidate, arith_lhs=True).sexpr())
        else:
            return False, f"Reason Unknown {teacher.reason_unknown()}"
    return False, f"Reached max iteration {num_max_iterations}."


def validate_cexs(teacher: "TeacherBase", candidate: "z3.BoolRef",
                  cexs: List[Tuple[float, ...]]) -> bool:
    """Return True when the teacher reports none of ``cexs`` as spurious.

    Every counterexample is passed to ``teacher.is_spurious_example`` together
    with the teacher's ``state_dim`` and ``perc_dim``.  Spurious ones are
    printed, one per line, before False is returned.
    """
    spurious_cexs = [
        cex for cex in cexs
        if teacher.is_spurious_example(teacher.state_dim, teacher.perc_dim, candidate, cex)
    ]
    if not spurious_cexs:
        return True
    print("Spurious CEXs:", *spurious_cexs, sep='\n')
    return False


def search_part(partition, state):
    """Locate the cell of ``partition`` that contains ``state``.

    Args:
        partition: One sorted sequence of boundary values per dimension.
        state: One coordinate per dimension; must have the same length as
            ``partition``.

    Returns:
        A tuple with one ``(lower, upper)`` pair per dimension, or None when
        some coordinate falls outside its boundaries.  ``np.searchsorted`` is
        used with its default left side, so a coordinate equal to the first
        boundary is treated as outside while one equal to a later boundary
        belongs to the cell ending at that boundary.
    """
    assert len(partition) == len(state)
    bounds = []
    for sorted_list, v in zip(partition, state):
        i = np.searchsorted(sorted_list, v)
        if i == 0 or i == len(sorted_list):
            return None
        bounds.append((sorted_list[i - 1], sorted_list[i]))
    return tuple(bounds)


def synth_dtree_per_part(
        part_to_examples: Dict[Hashable, "DataSet"],
        teacher_builder: Callable[[], "TeacherBase"],
        learner_builder: Callable[[], "LearnerBase"],
        num_max_iter: int,
        ult_bound: float = 0.0,
        feature_domain: Literal["concat", "diff"] = "diff") -> List[Dict]:
    """Run ``synth_dtree`` independently for every partition.

    Args:
        part_to_examples: Maps a partition key to its data set.  The key is
            converted to a float array and transposed into ``lb``/``ub``, so
            it is assumed to be a sequence of ``(lower, upper)`` pairs, one
            per dimension -- TODO confirm against the caller.
        teacher_builder: Zero-argument factory for a fresh teacher.
        learner_builder: Zero-argument factory for a fresh learner.
        num_max_iter: Iteration limit forwarded to ``synth_dtree``.
        ult_bound: Stored verbatim under ``"ultimate_bound"`` in each record.
        feature_domain: Stored verbatim under ``"feature_domain"`` in each
            record.

    Returns:
        One dict per processed partition with the keys ``"part"``,
        ``"feature_domain"``, ``"ultimate_bound"``, ``"status"`` (``"found"``,
        ``"not found"`` or ``"exception"``) and ``"result"``.  The list is
        shorter than the input when the user interrupts with Ctrl+C.
    """
    result = []
    for i, (part, dataset) in enumerate(part_to_examples.items()):
        safe_dps, unsafe_dps, num_nan = dataset.safe_dps, dataset.unsafe_dps, dataset.num_nan_dps
        print("#"*80)
        print(f"# safe: {len(safe_dps)}; "
              f"# unsafe: {len(unsafe_dps)}; "
              f"# NaN: {num_nan}")
        # np.asfarray was removed in NumPy 2.0; this is its documented
        # equivalent (float64 conversion).
        lb, ub = np.asarray(part, dtype=np.float64).T

        # NOTE We create new teacher and learner for each part to avoid solutions from other parts
        # TODO incremental solving
        teacher = teacher_builder()
        teacher.set_old_state_bound(lb=lb, ub=ub)

        learner = learner_builder()
        learner.add_positive_examples(*safe_dps)
        learner.add_negative_examples(*unsafe_dps)

        # Fields shared by every record, whatever the outcome.
        record = {"part": part,
                  "feature_domain": feature_domain,
                  "ultimate_bound": ult_bound}
        try:
            found, ret = synth_dtree(learner, teacher,
                                     num_max_iterations=num_max_iter)
            print(f"Found? {found}")
            if found:
                k, expr = ret
                result.append({**record,
                               "status": "found",
                               "result": {"k": k, "formula": expr}})
            else:
                result.append({**record,
                               "status": "not found",
                               "result": ret})
        except KeyboardInterrupt:
            print("User pressed Ctrl+C. Skip all remaining partition.")
            break  # NOTE finally block is executed before break
        except Exception as e:
            result.append({**record,
                           "status": "exception",
                           "result": traceback.format_exc()})
            print(e)
        finally:
            # Keep the per-partition data file (presumably written during
            # synthesis -- verify against the learner) under a unique name.
            # A failure may happen before the file exists; do not let the
            # cleanup itself abort the loop and lose the collected results.
            data_file = pathlib.Path("out/pre.data")
            if data_file.exists():
                data_file.rename(f"out/part-{i:03}-pre.data")
            else:
                print(f"No {data_file} found for part {i}; nothing to rename.")
            del teacher
            del learner
    return result