import dataclasses
import pathlib
import traceback
from typing import Callable, Dict, Hashable, List, Literal, Tuple

import numpy as np
import z3
from z3 import *

from learner_base import LearnerBase
from teacher_base import TeacherBase
# Kept after the z3 star-import so nothing exported by z3 can shadow it.
import time


@dataclasses.dataclass
class DataSet:
    """Sample points collected for one partition, grouped by label."""

    # Points labelled safe; given to the learner as positive examples.
    safe_dps: List[Tuple[float, ...]] = dataclasses.field(default_factory=list)
    # Points labelled unsafe; given to the learner as negative examples.
    unsafe_dps: List[Tuple[float, ...]] = dataclasses.field(default_factory=list)
    # Number of samples that evaluated to NaN; only reported, not learned from.
    num_nan_dps: int = 0


def _summarize_candidate(candidate, k: int) -> Dict:
    """Simplify *candidate* and return its iteration index plus s-expr/SMT-LIB text."""
    z3_expr = z3.simplify(candidate, arith_lhs=True)
    return {"k": k, "formula": z3_expr.sexpr(), "smtlib": z3_expr.serialize()}


def synth_dtree(learner: LearnerBase, teacher: TeacherBase, num_max_iterations: int = 10):
    """Run the learner/teacher counterexample-guided loop.

    Each iteration asks the learner for a candidate formula and asks the teacher
    to check it:

    * ``z3.sat``     -- the teacher produced counterexamples (``teacher.model()``);
      they are added to the learner as negative examples and the loop continues.
    * ``z3.unsat``   -- the candidate is accepted.
    * ``z3.unknown`` -- the loop stops without a result.

    Args:
        learner: produces candidate z3 formulas via ``learn()``.
        teacher: checks candidates via ``check()``.
        num_max_iterations: maximum number of learner/teacher rounds.

    Returns:
        ``(found, info, (teacher_time, learner_time))`` where the two times are
        cumulative elapsed seconds spent in the teacher and the learner.
        ``info`` is a dict with keys ``"k"``, ``"formula"`` (s-expression) and
        ``"smtlib"`` when the teacher answers unsat; on unknown it additionally
        starts with ``"reason"``; when the iteration budget runs out it is a
        message string.
    """
    teacher_time = 0.0
    learner_time = 0.0
    for k in range(num_max_iterations):
        print("=" * 80)
        print(f"Iteration {k}:")
        print("learning ....")
        start = time.perf_counter()
        candidate = learner.learn()
        learner_time += time.perf_counter() - start
        print("done learning")
        print(f"candidate: {candidate}")

        # Query the teacher for negative examples (counterexamples).
        start = time.perf_counter()
        result = teacher.check(candidate)
        teacher_time += time.perf_counter() - start
        print(f"Satisfiability: {result}")

        if result == z3.sat:
            negative_examples = teacher.model()
            assert len(negative_examples) > 0
            print(f"negative examples: {negative_examples}")
            # Debug hook: validate_cexs(teacher, candidate, negative_examples)
            # reports counterexamples the teacher itself considers spurious.
            learner.add_negative_examples(*negative_examples)
            continue

        if result == z3.unsat:
            print("we are done!")
            return True, _summarize_candidate(candidate, k), (teacher_time, learner_time)

        assert result == z3.unknown
        ret = {
            "reason": f"Reason Unknown {teacher.reason_unknown()}",
            **_summarize_candidate(candidate, k),
        }
        return False, ret, (teacher_time, learner_time)

    return False, f"Reached max iteration {num_max_iterations}.", (teacher_time, learner_time)


# Tactics applied (in this order) after clause splitting inside tacticSimplify.
# Each one is wrapped in OrElse(..., 'skip') so a tactic that fails on a goal is
# simply skipped instead of aborting the pipeline.
_SIMPLIFY_TACTICS = (
    'ctx-solver-simplify',
    'unit-subsume-simplify',
    'propagate-ineqs',
    'purify-arith',
    'ctx-simplify',
    'dom-simplify',
    'propagate-values',
    'simplify',
    'aig',
    'degree-shift',
    'factor',
    'lia2pb',
    'recover-01',
    'elim-term-ite',  # needed to remove if-then-else terms
    'injectivity',
    'snf',
    'reduce-args',
    'elim-and',
    'symmetry-reduce',
    'macro-finder',
    'quasi-macros',
)


def tacticSimplify(formula: BoolRef):
    """Simplify *formula* with a repeated pipeline of z3 tactics.

    The tactic result is a list of subgoals read as a disjunction of
    conjunctions (the same reading as ``ApplyResult.as_expr``):

    * no subgoals at all            -> ``False``
    * any subgoal with no formulas  -> that disjunct is ``True``, so ``True``
    * otherwise                     -> ``Or`` of the simplified ``And`` of each
      subgoal (just the single conjunction when there is one subgoal).

    Note: some tactics in the pipeline (e.g. ``purify-arith``, ``lia2pb``) may
    introduce auxiliary variables, so the result is not guaranteed to be
    syntactically over the same variables as *formula*.

    Side effect: changes global z3 printing options via ``set_option`` and
    ``set_fpa_pretty``.
    """
    print("in custom simplify" + "-" * 50)
    set_option(max_args=10000000, max_lines=1000000, max_depth=10000000, max_visited=1000000)
    set_option(html_mode=False)
    set_fpa_pretty(flag=False)

    goal = Goal()
    goal.add(formula)
    steps = [OrElse(Tactic(name), Tactic('skip')) for name in _SIMPLIFY_TACTICS]
    pipeline = Repeat(Then(Repeat(OrElse(Tactic('split-clause'), Tactic('skip'))), *steps))
    subgoals = list(pipeline(goal))

    if not subgoals:
        # Every branch was closed: the disjunction is empty, i.e. False.
        print("custom simplify produced no subgoals; result is False")
        return z3.BoolVal(False)
    if any(len(subgoal) == 0 for subgoal in subgoals):
        # An empty goal has no constraints, i.e. it is True, and so is the
        # whole disjunction. (Dropping it would wrongly strengthen the formula.)
        print("custom simplify produced a trivially true subgoal; result is True")
        return z3.BoolVal(True)

    disjuncts = [z3.simplify(And(list(subgoal))) for subgoal in subgoals]
    if len(disjuncts) == 1:
        print(disjuncts[0])
        return disjuncts[0]

    simplifiedGtdc = Or(disjuncts)
    print("custom simplified gtdc" + "-" * 50)
    print(simplifiedGtdc)
    return simplifiedGtdc


def validate_cexs(teacher: TeacherBase, candidate: z3.BoolRef, cexs: List[Tuple[float]]) -> bool:
    """Return True iff the teacher deems none of *cexs* spurious for *candidate*.

    Spurious counterexamples (per ``teacher.is_spurious_example``) are printed
    one per line before returning False.
    """
    spurious_cexs = [cex for cex in cexs
                     if teacher.is_spurious_example(teacher.state_dim, teacher.perc_dim, candidate, cex)]
    if not spurious_cexs:
        return True
    print("Spurious CEXs:", *spurious_cexs, sep='\n')
    return False


def search_part(partition, state):
    """Locate the grid cell of *partition* that contains *state*.

    Args:
        partition: one sorted sequence of breakpoints per dimension.
        state: one coordinate per dimension (same length as *partition*).

    Returns:
        A tuple of ``(lower, upper)`` breakpoint pairs, one per dimension, such
        that ``lower < v <= upper`` for each coordinate ``v`` (``np.searchsorted``
        with its default ``side='left'``), or ``None`` if any coordinate falls
        outside ``(first breakpoint, last breakpoint]`` of its dimension.
    """
    assert len(partition) == len(state)
    bounds = []
    for sorted_list, v in zip(partition, state):
        i = np.searchsorted(sorted_list, v)
        if i == 0 or i == len(sorted_list):
            return None
        bounds.append((sorted_list[i - 1], sorted_list[i]))
    return tuple(bounds)


def synth_dtree_per_part(
        part_to_examples: Dict[Hashable, DataSet],
        teacher_builder: Callable[[], TeacherBase],
        learner_builder: Callable[[TeacherBase], LearnerBase],
        num_max_iter: int,
        ult_bound: float = 0.0,
        feature_domain: Literal["concat", "diff"] = "diff") -> List[Dict]:
    """Run :func:`synth_dtree` independently on every partition.

    Args:
        part_to_examples: maps a partition key to its sample points. The key is
            converted with ``np.asarray(part, dtype=float).T`` and unpacked into
            lower/upper bounds, so it is expected to be a sequence of
            ``(lower, upper)`` pairs, one per dimension (as from :func:`search_part`).
        teacher_builder: creates a fresh teacher for each partition.
        learner_builder: creates a fresh learner bound to a given teacher.
        num_max_iter: iteration budget passed to :func:`synth_dtree`.
        ult_bound: recorded in each result entry.
        feature_domain: recorded in each result entry.

    Returns:
        One dict per processed partition with a ``"status"`` of ``"skip"``
        (no safe points), ``"found"``, ``"not found"`` or ``"exception"``.
        Processing stops early (returning what was collected) on Ctrl+C.
    """
    result = []
    for i, (part, dataset) in enumerate(part_to_examples.items()):
        safe_dps, unsafe_dps, num_nan = dataset.safe_dps, dataset.unsafe_dps, dataset.num_nan_dps
        # A partition without safe datapoints has nothing to learn from.
        if len(safe_dps) == 0:
            result.append({"part": part, "status": "skip", "result": "No safe datapoints."})
            continue

        print("#" * 80)
        print(f"# safe: {len(safe_dps)}; "
              f"# unsafe: {len(unsafe_dps)}; "
              f"# NaN: {num_nan}")
        # np.asfarray was removed in NumPy 2.0; asarray with a float dtype is equivalent.
        lb, ub = np.asarray(part, dtype=float).T
        # NOTE A new teacher and learner are created for each part so that
        #   solutions from other parts cannot leak in.
        # TODO incremental solving
        teacher = teacher_builder()
        teacher.set_old_state_bound(lb=lb, ub=ub)
        learner = learner_builder(teacher)
        learner.add_positive_examples(*safe_dps)
        learner.add_negative_examples(*unsafe_dps)
        try:
            found, ret, time_info = synth_dtree(learner, teacher, num_max_iterations=num_max_iter)
            print(f"Found? {found}")
            result.append({
                "part": part,
                "feature_domain": feature_domain,
                "ultimate_bound": ult_bound,
                "status": "found" if found else "not found",
                "result": ret,
                "teacher time": time_info[0],
                "learner time": time_info[1],
            })
        except KeyboardInterrupt:
            print("User pressed Ctrl+C. Skip all remaining partition.")
            break  # NOTE the finally block is executed before break
        except Exception as e:
            result.append({"part": part, "status": "exception", "result": traceback.format_exc()})
            print(e)
        finally:
            # NOTE(review): out/pre.data is presumably written by the teacher while
            # checking (not visible here); archive it under a per-partition name.
            # Guard against a missing file: raising inside `finally` would mask the
            # exception recorded above and abort all remaining partitions.
            data_file = pathlib.Path("out/pre.data")
            if data_file.exists():
                # Path.replace overwrites an existing target on every platform.
                data_file.replace(f"out/part-{i:03}-pre.data")
            else:
                print(f"No {data_file} to archive for partition {i}.")
            del teacher
            del learner
    return result