Newer
Older
from typing import Callable, Dict, Hashable, List, Literal, Tuple
from learner_base import LearnerBase
from teacher_base import TeacherBase
@dataclasses.dataclass()
class DataSet:
    """Sample points collected for a single partition.

    Every data point is a tuple of floats. A freshly constructed instance is
    empty, and each instance owns its own lists.
    """

    # Points labelled safe (handed to the learner as positive examples).
    safe_dps: List[Tuple[float, ...]] = dataclasses.field(default_factory=list)
    # Points labelled unsafe (handed to the learner as negative examples).
    unsafe_dps: List[Tuple[float, ...]] = dataclasses.field(default_factory=list)
    # Number of NaN data points seen; only the count is kept.
    num_nan_dps: int = dataclasses.field(default=0)
def synth_dtree(learner: LearnerBase, teacher: TeacherBase, num_max_iterations: int = 10):
    """Alternate between the learner and the teacher for at most ``num_max_iterations`` rounds.

    Each round prints the current candidate, then asks ``teacher.check`` about it:

    * ``z3.sat``: the teacher's model is fed back to the learner as negative
      examples and the next round starts.
    * ``z3.unsat``: success. Returns ``(True, info, (teacher_time, learner_time))``.
      ``info`` holds the iteration ``k`` and the candidate simplified with
      ``z3.simplify(..., arith_lhs=True)``, both as ``sexpr()`` and ``serialize()``.
    * Anything else: gives up. Returns ``(False, info, times)``, where ``info``
      additionally carries ``teacher.reason_unknown()``.

    If no round concludes, returns
    ``(False, "Reached max iteration <n>.", (teacher_time, learner_time))``.

    NOTE(review): ``candidate``, ``past_candidate_list``, ``teacher_time`` and
    ``learner_time`` are never assigned in this copy, and ``z3`` is not
    imported in the visible imports. Lines appear to be missing, presumably the
    learner call producing ``candidate`` and the timing bookkeeping.
    TODO restore them.
    """
    for k in range(num_max_iterations):
        # NOTE(review): sep='' has no effect with a single positional argument;
        # possibly end='' was intended.
        print(f"Iteration {k}:", sep='')
        print("learning ....")
        # NOTE(review): `candidate` is not assigned anywhere in the visible code.
        print(f"candidate: {candidate}")
        past_candidate_list.append(candidate)
        # QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES
        result = teacher.check(candidate)
        if result == z3.sat:
            # The teacher's model supplies counterexamples; hand them to the
            # learner as negative examples and start the next iteration.
            negative_examples = teacher.model()
            print(f"negative examples: {negative_examples}")
            # assert validate_cexs(teacher, candidate, negative_examples)
            learner.add_negative_examples(*negative_examples)
            continue
        elif result == z3.unsat:
            # No negative example exists: report the simplified candidate.
            z3_expr = z3.simplify(candidate, arith_lhs=True)
            ret = {"k": k, "formula": z3_expr.sexpr(), "smtlib": z3_expr.serialize()}
            return True, ret, (teacher_time,learner_time)
        # Reached only when `result` is neither sat nor unsat (both branches
        # above leave this iteration), i.e. the teacher's answer is unknown.
        z3_expr = z3.simplify(candidate, arith_lhs=True)
        ret = {
            "reason": f"Reason Unknown {teacher.reason_unknown()}",
            "k": k, "formula": z3_expr.sexpr(), "smtlib": z3_expr.serialize()
        }
        return False, ret, (teacher_time,learner_time)
    return False, f"Reached max iteration {num_max_iterations}.", (teacher_time,learner_time)
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def tacticSimplify(formula: BoolRef):
    """Simplify ``formula`` with a fixed, repeated pipeline of Z3 tactics.

    The goal is first split with ``split-clause``, which may produce several
    subgoals. It is then passed through a list of simplification tactics. Each
    tactic is wrapped in ``OrElse(..., skip)``, so a tactic that fails leaves
    the goal unchanged instead of aborting the pipeline. The resulting subgoals
    are turned back into one formula: each subgoal becomes the conjunction of
    its formulas, and the subgoals are combined with ``Or``.

    Side effects: prints progress and the result to stdout, and changes Z3's
    global printing options (``set_option`` / ``set_fpa_pretty``).

    Args:
        formula: Boolean Z3 expression to simplify.

    Returns:
        The simplified expression. A single subgoal is returned on its own;
        several subgoals are returned as their disjunction.

    Calls ``sys.exit(-9)`` if the tactic pipeline yields no subgoals at all.

    NOTE(review): some tactics in the pipeline (e.g. 'elim-term-ite',
    'purify-arith', 'lia2pb', 'symmetry-reduce') may introduce fresh symbols or
    only preserve satisfiability. Confirm the result is equivalent to `formula`
    for the inputs this is used on.
    """
    print("in custom simplify" + "-"*50)
    set_option(max_args=10000000, max_lines=1000000, max_depth=10000000, max_visited=1000000)
    set_option(html_mode=False)
    set_fpa_pretty(flag=False)

    goal = Goal()
    goal.add(formula)

    def _try(name):
        # Apply tactic `name` if it succeeds, otherwise leave the goal as is.
        return OrElse(Tactic(name), Tactic('skip'))

    simplification_tactics = [
        'ctx-solver-simplify',
        'unit-subsume-simplify',
        'propagate-ineqs',
        'purify-arith',
        'ctx-simplify',
        'dom-simplify',
        'propagate-values',
        'simplify',
        'aig',
        'degree-shift',
        'factor',
        'lia2pb',
        'recover-01',
        'elim-term-ite',  # required: removes if-then-else terms
        'injectivity',
        'snf',
        'reduce-args',
        'elim-and',
        'symmetry-reduce',
        'macro-finder',
        'quasi-macros',
    ]
    pipeline = Repeat(Then(
        Repeat(_try('split-clause')),
        *[_try(name) for name in simplification_tactics],
    ))
    subgoals = pipeline(goal)

    if len(subgoals) == 0:
        print("there is an error in the custom simplify Z3")
        sys.exit(-9)

    # A subgoal with no formulas left is trivially true. Goal.as_expr() maps
    # it to True (and a single formula to itself, several to their And), so
    # such a subgoal is kept as a True disjunct rather than being dropped,
    # which would make the result stronger than the input.
    disjuncts = [z3.simplify(subgoal.as_expr()) for subgoal in subgoals]

    if len(disjuncts) == 1:
        print(disjuncts[0])
        return disjuncts[0]
    simplified = Or(disjuncts)
    print("custom simplified gtdc" + "-"*50)
    print(simplified)
    return simplified
def validate_cexs(teacher: TeacherBase,
                  candidate,
                  cexs: List[Tuple[float, ...]]) -> bool:
    """Check that none of the counterexamples reported for ``candidate`` is spurious.

    Each counterexample is tested with ``teacher.is_spurious_example``. The
    call receives the teacher's ``state_dim`` and ``perc_dim``, the candidate
    and the counterexample. Any spurious counterexamples are printed, one per
    line.

    The argument order matches the call
    ``validate_cexs(teacher, candidate, negative_examples)`` used in
    ``synth_dtree``.

    Args:
        teacher: Object providing ``is_spurious_example``, ``state_dim`` and
            ``perc_dim``.
        candidate: The candidate the counterexamples were produced for.
        cexs: Counterexamples to validate.

    Returns:
        True if no counterexample is spurious (including when ``cexs`` is
        empty), otherwise False.
    """
    spurious_cexs = [cex for cex in cexs
                     if teacher.is_spurious_example(teacher.state_dim, teacher.perc_dim, candidate, cex)]
    if spurious_cexs:
        print("Spurious CEXs:", *spurious_cexs, sep='\n')
        return False
    return True
def search_part(partition, state):
    """Find the grid cell of ``partition`` that contains ``state``.

    ``partition`` holds, for each dimension, a sorted sequence of cut points.
    ``state`` holds one coordinate per dimension. For each dimension,
    ``np.searchsorted`` (default ``side='left'``) finds the consecutive cut
    points ``(lo, hi)`` with ``lo < v <= hi``.

    Returns:
        A tuple with one ``(lo, hi)`` pair per dimension. It is hashable, so
        it can be used as a partition key (presumably matching the ``part``
        keys of ``synth_dtree_per_part``; confirm). Returns ``None`` if any
        coordinate is ``<=`` its first cut point or ``>`` its last cut point.
    """
    assert len(partition) == len(state)
    bounds = []
    for cuts, v in zip(partition, state):
        i = np.searchsorted(cuts, v)
        # i == 0: v is at or below the first cut; i == len: v is above the last.
        if i == 0 or i == len(cuts):
            return None
        bounds.append((cuts[i - 1], cuts[i]))
    return tuple(bounds)
def synth_dtree_per_part(
        part_to_examples: Dict[Hashable, DataSet],
        teacher_builder: Callable[[], TeacherBase],
        learner_builder: Callable[[TeacherBase], LearnerBase],
        num_max_iter: int,
        ult_bound: float = 0.0,
        feature_domain: Literal["concat", "diff"] = "diff") -> List[Dict]:
    """Run the synthesis separately for every partition in ``part_to_examples``.

    For each ``(part, dataset)`` pair:

    * A partition without safe data points gets a ``"skip"`` entry.
    * Otherwise the safe/unsafe/NaN counts are printed. The teacher's
      old-state bound is set from ``lb, ub = np.asfarray(part).T``; ``part``
      is presumably a sequence of ``(lb, ub)`` pairs, one per dimension
      (confirm). A learner is built for the teacher and given the safe points
      as positive and the unsafe points as negative examples. A result dict
      is then appended, holding the partition, ``feature_domain``,
      ``ult_bound`` and the teacher/learner times.

    Ctrl+C stops processing the remaining partitions. Any other exception is
    recorded with status ``"exception"`` and its traceback. The ``finally``
    clause renames ``out/pre.data`` to ``out/part-<i:03>-pre.data``.

    Args:
        part_to_examples: Maps each partition key to its ``DataSet``.
        teacher_builder: Zero-argument teacher factory (per the annotation; not
            called in the visible lines).
        learner_builder: Builds a learner for a given teacher.
        num_max_iter: Not used in the visible lines.
        ult_bound: Copied into each result dict as ``"ultimate_bound"``.
        feature_domain: Copied into each result dict as ``"feature_domain"``.

    Returns:
        Per the annotation, a list of per-partition result dicts.

    NOTE(review): this copy is incomplete and does not compile as written:

    * ``result`` is never initialised or returned.
    * ``teacher`` is never created (presumably ``teacher_builder()``).
    * ``found``/``ret``/``time_info`` are never assigned (presumably from
      ``synth_dtree(learner, teacher, num_max_iter)``).
    * The ``try:`` that the ``except``/``finally`` clauses belong to is
      missing.
    * The success entry has no ``"status"``/``"result"`` key, unlike the other
      entries; confirm whether that is intended.
    """
    for i, (part, dataset) in enumerate(part_to_examples.items()):
        safe_dps, unsafe_dps, num_nan = dataset.safe_dps, dataset.unsafe_dps, dataset.num_nan_dps
        # if a partition has no safe datapoints, then skip the partition
        if len(safe_dps) == 0:
            result.append({"part": part,
                           "status": "skip",
                           "result": "No safe datapoints."})
            continue
        print(f"# safe: {len(safe_dps)}; "
              f"# unsafe: {len(unsafe_dps)}; "
              f"# NaN: {num_nan}")
        # NOTE(review): np.asfarray was removed in NumPy 2.0;
        # np.asarray(part, dtype=float) is the replacement.
        lb, ub = np.asfarray(part).T
        # NOTE We create new teacher and learner for each part to avoid solutions from other parts
        teacher.set_old_state_bound(lb=lb, ub=ub)
        learner = learner_builder(teacher)
        learner.add_positive_examples(*safe_dps)
        learner.add_negative_examples(*unsafe_dps)
        print(f"Found? {found}")
        if found:
            result.append({"part": part,
                           "feature_domain": feature_domain,
                           "ultimate_bound": ult_bound,
                           "teacher time": time_info[0],
                           "learner time": time_info[1]})
        else:
            result.append({"part": part,
                           "feature_domain": feature_domain,
                           "ultimate_bound": ult_bound,
                           "result": ret,
                           "teacher time": time_info[0],
                           "learner time": time_info[1]})
        # NOTE(review): the `try:` matching the handlers below is missing.
        except KeyboardInterrupt:
            print("User pressed Ctrl+C. Skip all remaining partition.")
            break  # NOTE finally block is executed before break
        except Exception as e:
            result.append({"part": part,
                           "status": "exception",
                           "result": traceback.format_exc()})
            print(e)
        finally:
            # Keep each partition's data file under a per-partition name.
            # NOTE(review): Path.rename raises (e.g. FileNotFoundError) if
            # out/pre.data does not exist; confirm it is always produced.
            data_file = pathlib.Path("out/pre.data")
            data_file.rename(f"out/part-{i:03}-pre.data")