# dtree_synth.py
import dataclasses
import pathlib
import traceback
from typing import Callable, Dict, Hashable, List, Literal, Tuple

import numpy as np
import z3
from learner_base import LearnerBase
from teacher_base import TeacherBase
# (repository blame annotation: aastorg2 committed)
import time
@dataclasses.dataclass
class DataSet:
    """Container for the labelled datapoints of one partition.

    Every field has a default, so ``DataSet()`` yields two fresh empty lists
    and a zero counter; the lists are never shared between instances.
    """

    # Datapoints labelled safe; each one is a tuple of floats.
    safe_dps: List[Tuple[float, ...]] = dataclasses.field(default_factory=list)

    # Datapoints labelled unsafe; same tuple-of-floats layout.
    unsafe_dps: List[Tuple[float, ...]] = dataclasses.field(default_factory=list)

    # Counter of NaN datapoints (only the count is kept, not the points).
    num_nan_dps: int = 0
def synth_dtree(learner: LearnerBase, teacher: TeacherBase, num_max_iterations: int = 10):
    """Run the learner/teacher refinement loop for a bounded number of rounds.

    In every round ``learner.learn()`` proposes a candidate which is handed to
    ``teacher.check()``. The teacher's verdict decides what happens next:

    * ``z3.sat``   -- ``teacher.model()`` supplies negative examples, which are
      fed back through ``learner.add_negative_examples`` before the next round.
    * ``z3.unsat`` -- the candidate is accepted and returned (simplified).
    * otherwise    -- the verdict is expected to be ``z3.unknown``; the loop
      stops and the teacher's ``reason_unknown()`` is reported.

    Args:
        learner: Object providing ``learn()`` and ``add_negative_examples()``.
        teacher: Object providing ``check()``, ``model()`` and
            ``reason_unknown()``.
        num_max_iterations: Upper bound on the number of rounds.

    Returns:
        A 3-tuple ``(found, info, (teacher_time, learner_time))``.
        ``found`` is ``True`` only for an ``unsat`` verdict. ``info`` is a dict
        with keys ``"k"``, ``"formula"`` and ``"smtlib"`` (plus ``"reason"``
        for an unknown verdict), or a plain message string when the iteration
        budget is exhausted. The two times are the accumulated wall-clock
        seconds spent inside ``teacher.check`` and ``learner.learn``.
    """
    teacher_time = 0.0
    learner_time = 0.0
    for k in range(num_max_iterations):
        print("="*80)
        print(f"Iteration {k}:")
        print("learning ....")

        learn_start = time.time()
        candidate = learner.learn()
        learner_time += time.time() - learn_start

        print("done learning")
        print(f"candidate: {candidate}")

        # Ask the teacher whether the candidate still admits negative examples.
        check_start = time.time()
        result = teacher.check(candidate)
        teacher_time += time.time() - check_start

        print(f"Satisfiability: {result}")
        if result == z3.sat:
            negative_examples = teacher.model()
            assert len(negative_examples) > 0
            print(f"negative examples: {negative_examples}")

            # assert validate_cexs(teacher, candidate, negative_examples)
            learner.add_negative_examples(*negative_examples)
        elif result == z3.unsat:
            print("we are done!")
            z3_expr = z3.simplify(candidate, arith_lhs=True)
            ret = {"k": k, "formula": z3_expr.sexpr(), "smtlib": z3_expr.serialize()}
            return True, ret, (teacher_time, learner_time)
        else:
            # Neither sat nor unsat: stop here rather than silently looping on
            # with a verdict the learner cannot use.
            assert result == z3.unknown
            z3_expr = z3.simplify(candidate, arith_lhs=True)
            ret = {
                "reason": f"Reason Unknown {teacher.reason_unknown()}",
                "k": k, "formula": z3_expr.sexpr(), "smtlib": z3_expr.serialize()
            }
            return False, ret, (teacher_time, learner_time)

    return False, f"Reached max iteration {num_max_iterations}.", (teacher_time, learner_time)
def tacticSimplify(formula: z3.BoolRef):
    """Simplify ``formula`` with a fixed pipeline of Z3 tactics.

    The formula is placed in a ``z3.Goal`` and a repeated tactic chain is
    applied to it. Every tactic is wrapped as ``OrElse(tactic, skip)`` so that
    a tactic which fails on the goal is skipped instead of aborting the chain.
    The literals of each non-empty resulting subgoal are conjoined and
    simplified; the per-subgoal conjunctions are then combined with ``Or``.

    Side effects: changes global Z3 pretty-printer options, prints progress
    and the simplified formula, and calls ``sys.exit(-9)`` when the tactic
    application leaves no non-empty subgoal.

    Args:
        formula: Boolean Z3 expression to simplify.

    Returns:
        The single simplified conjunction when there is exactly one non-empty
        subgoal, otherwise the ``Or`` of all simplified conjunctions.
    """
    import sys  # local import: only needed for the failure exit below

    print("in custom simplify" + "-"*50)

    z3.set_option(max_args=10000000, max_lines=1000000,
                  max_depth=10000000, max_visited=1000000)
    z3.set_option(html_mode=False)
    z3.set_fpa_pretty(flag=False)

    goal = z3.Goal()
    goal.add(formula)

    # Applied in this order after the clause-splitting step.
    tactic_names = (
        'ctx-solver-simplify',
        'unit-subsume-simplify',
        'propagate-ineqs',
        'purify-arith',
        'ctx-simplify',
        'dom-simplify',
        'propagate-values',
        'simplify',
        'aig',
        'degree-shift',
        'factor',
        'lia2pb',
        'recover-01',
        'elim-term-ite',  # must to remove ite
        'injectivity',
        'snf',
        'reduce-args',
        'elim-and',
        'symmetry-reduce',
        'macro-finder',
        'quasi-macros',
    )
    split_clauses = z3.Repeat(z3.OrElse(z3.Tactic('split-clause'), z3.Tactic('skip')))
    guarded = [z3.OrElse(z3.Tactic(name), z3.Tactic('skip')) for name in tactic_names]
    pipeline = z3.Repeat(z3.Then(split_clauses, *guarded))

    # Drop empty subgoals. The list must be materialised *before* the
    # emptiness test: a bare ``filter`` object is always truthy in Python 3.
    subgoals = list(filter(None, pipeline(goal)))
    if not subgoals:
        print("there is an error in the custom simplify Z3")
        sys.exit(-9)

    # Each subgoal is a conjunction of literals; the subgoals together form a
    # disjunction.
    conjunctions = [z3.simplify(z3.And(list(subgoal))) for subgoal in subgoals]
    if len(conjunctions) == 1:
        print(conjunctions[0])
        return conjunctions[0]

    simplified = z3.Or(conjunctions)
    print("custom simplified gtdc" + "-"*50)
    print(simplified)
    return simplified
def validate_cexs(teacher: TeacherBase,
                  candidate: z3.BoolRef,
                  cexs: List[Tuple[float]]) -> bool:
    """Report whether none of the counterexamples is spurious.

    Every element of ``cexs`` is passed to ``teacher.is_spurious_example``
    together with the teacher's ``state_dim`` and ``perc_dim`` and the
    candidate. When at least one is flagged, all flagged counterexamples are
    printed, one per line, and ``False`` is returned; otherwise ``True``.
    """
    flagged = []
    for example in cexs:
        if teacher.is_spurious_example(teacher.state_dim, teacher.perc_dim, candidate, example):
            flagged.append(example)

    if flagged:
        print("Spurious CEXs:", *flagged, sep='\n')
        return False
    return True
def search_part(partition, state):
    """Locate the cell of ``partition`` that contains ``state``.

    ``partition`` holds one sorted sequence of boundaries per dimension and
    ``state`` one value per dimension. For each dimension the value is placed
    with ``np.searchsorted`` and the neighbouring pair of boundaries is
    recorded.

    Returns a tuple of ``(lower, upper)`` pairs, one per dimension, or
    ``None`` as soon as a value has no boundary on one of its sides.
    """
    assert len(partition) == len(state)

    cell = []
    for edges, value in zip(partition, state):
        pos = np.searchsorted(edges, value)
        # pos == 0 -> nothing to the left; pos == len(edges) -> nothing to
        # the right. Either way the value is outside the partitioned range.
        if not (pos != 0 and pos != len(edges)):
            return None
        lower, upper = edges[pos - 1], edges[pos]
        cell.append((lower, upper))
    return tuple(cell)


def synth_dtree_per_part(
        part_to_examples: Dict[Hashable, DataSet],
        teacher_builder: Callable[[], TeacherBase],
        learner_builder: Callable[[TeacherBase], LearnerBase],
        num_max_iter: int,
        ult_bound: float = 0.0,
        feature_domain: Literal["concat", "diff"] = "diff") -> List[Dict]:
    """Run ``synth_dtree`` once per partition and collect one record per partition.

    For every ``(part, dataset)`` entry a fresh teacher and learner are built,
    the teacher is bounded to the partition, the learner is seeded with the
    dataset's safe/unsafe datapoints, and ``synth_dtree`` is run.

    Args:
        part_to_examples: Maps a partition key to its ``DataSet``. The key is
            converted to a float array and transposed to obtain the lower and
            upper bounds passed to ``teacher.set_old_state_bound``.
        teacher_builder: Zero-argument factory for a new teacher.
        learner_builder: Factory building a learner from a teacher.
        num_max_iter: Iteration budget forwarded to ``synth_dtree``.
        ult_bound: Only echoed into result records as ``"ultimate_bound"``.
        feature_domain: Only echoed into result records as ``"feature_domain"``.

    Returns:
        A list of dicts, each with ``"part"``, ``"status"`` (``"skip"``,
        ``"found"``, ``"not found"`` or ``"exception"``) and ``"result"``.
        Found / not-found records additionally carry ``"feature_domain"``,
        ``"ultimate_bound"``, ``"teacher time"`` and ``"learner time"``.
        Ctrl+C stops processing and returns what was collected so far.
    """
    result = []
    for i, (part, dataset) in enumerate(part_to_examples.items()):
        print(f"Partition # {i}")
        safe_dps, unsafe_dps, num_nan = dataset.safe_dps, dataset.unsafe_dps, dataset.num_nan_dps
        # if a partition has no safe datapoints, then skip the partition
        if len(safe_dps) == 0:
            result.append({"part": part,
                           "status": "skip",
                           "result": "No safe datapoints."})
            continue

        print("#"*80)
        print(f"# safe: {len(safe_dps)}; "
              f"# unsafe: {len(unsafe_dps)}; "
              f"# NaN: {num_nan}")
        # np.asfarray was removed in NumPy 2.0; this is the equivalent call.
        lb, ub = np.asarray(part, dtype=float).T

        # NOTE We create new teacher and learner for each part to avoid solutions from other parts
        # TODO incremental solving
        teacher = teacher_builder()
        teacher.set_old_state_bound(lb=lb, ub=ub)

        learner = learner_builder(teacher)
        learner.add_positive_examples(*safe_dps)
        learner.add_negative_examples(*unsafe_dps)
        try:
            found, ret, time_info = synth_dtree(learner, teacher,
                                                num_max_iterations=num_max_iter)
            print(f"Found? {found}")
            result.append({"part": part,
                           "feature_domain": feature_domain,
                           "ultimate_bound": ult_bound,
                           "status": "found" if found else "not found",
                           "result": ret,
                           "teacher time": time_info[0],
                           "learner time": time_info[1]})
        except KeyboardInterrupt:
            print("User pressed Ctrl+C. Skip all remaining partition.")
            break  # NOTE finally block is executed before break
        except Exception as e:
            result.append({"part": part,
                           "status": "exception",
                           "result": traceback.format_exc()})
            print(e)
        finally:
            # Keep a per-partition copy of "out/pre.data" (presumably written
            # during learning -- TODO confirm). Guarded so that a missing file
            # cannot raise from ``finally`` and discard the collected results.
            data_file = pathlib.Path("out/pre.data")
            if data_file.exists():
                data_file.rename(f"out/part-{i:03}-pre.data")
            del teacher
            del learner
    return result