Newer
Older
import dataclasses
import pathlib
import time
import traceback
from typing import Callable, Dict, Hashable, List, Literal, Tuple

import numpy as np
import z3

from learner_base import LearnerBase
from teacher_base import TeacherBase
@dataclasses.dataclass
class DataSet:
    """Labelled data points gathered for a single partition.

    Every data point is a tuple of floats. Each instance receives its own
    fresh, empty lists unless values are supplied explicitly.
    """

    # Points labelled safe; used to seed the learner's positive examples.
    safe_dps: List[Tuple[float, ...]] = dataclasses.field(
        default_factory=list,
    )
    # Points labelled unsafe; used to seed the learner's negative examples.
    unsafe_dps: List[Tuple[float, ...]] = dataclasses.field(
        default_factory=list,
    )
    # Count of NaN data points; only the count is kept, not the points.
    num_nan_dps: int = 0
def synth_dtree(learner: LearnerBase, teacher: TeacherBase, num_max_iterations: int = 10):
    """Run a learner/teacher loop until a candidate is accepted or the budget runs out.

    Each iteration obtains a candidate from ``learner`` and passes it to
    ``teacher.check``:

    - ``z3.sat``: the teacher's model (``teacher.model()``) is added to the
      learner as negative examples and the loop continues.
    - ``z3.unsat``: the candidate is accepted.
    - anything else: synthesis stops and reports the teacher's reason.

    Args:
        learner: Produces candidates and accepts negative examples.
        teacher: Checks candidates; exposes ``check``, ``model`` and
            ``reason_unknown``.
        num_max_iterations: Maximum number of learner/teacher rounds.

    Returns:
        ``(found, payload, (teacher_time, learner_time))`` where

        - ``(True, (k, z3.simplify(candidate, arith_lhs=True)), times)`` when
          the teacher answers ``z3.unsat`` at iteration ``k``;
        - ``(False, "Reason Unknown ...", times)`` for any other non-sat answer;
        - ``(False, "Reached max iteration N.", times)`` when the budget is
          exhausted.

        The times are accumulated seconds measured with ``time.perf_counter``
        around ``teacher.check`` and the learner call, respectively.
    """
    teacher_time = 0.0
    learner_time = 0.0
    for k in range(num_max_iterations):
        print(f"Iteration {k}:")
        print("learning ....")
        start = time.perf_counter()
        # NOTE(review): the line producing `candidate` is missing from the
        # reviewed source; this assumes LearnerBase exposes `learn()` — confirm.
        candidate = learner.learn()
        learner_time += time.perf_counter() - start
        print(f"candidate: {candidate}")

        # QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES
        start = time.perf_counter()
        result = teacher.check(candidate)
        teacher_time += time.perf_counter() - start
        if result == z3.sat:
            negative_examples = teacher.model()
            print(f"negative examples: {negative_examples}")
            # Debug hook: assert validate_cexs(teacher, candidate, negative_examples)
            learner.add_negative_examples(*negative_examples)
            continue
        elif result == z3.unsat:
            return True, (k, z3.simplify(candidate, arith_lhs=True)), (teacher_time, learner_time)
        else:
            # Neither sat nor unsat (e.g. the solver gave up): stop and report why.
            return False, f"Reason Unknown {teacher.reason_unknown()}", (teacher_time, learner_time)
    return False, f"Reached max iteration {num_max_iterations}.", (teacher_time, learner_time)
def validate_cexs(teacher: "TeacherBase", candidate,
                  cexs: List[Tuple[float, ...]]) -> bool:
    """Report whether every counterexample in ``cexs`` is genuine.

    A counterexample is treated as spurious when
    ``teacher.is_spurious_example(teacher.state_dim, teacher.perc_dim,
    candidate, cex)`` is truthy. All spurious counterexamples are printed,
    one per line, after the header ``"Spurious CEXs:"``.

    Args:
        teacher: Object exposing ``state_dim``, ``perc_dim`` and
            ``is_spurious_example``.
        candidate: The candidate formula the counterexamples refute.
        cexs: Counterexamples returned by the teacher.

    Returns:
        True when no counterexample is spurious (including an empty
        ``cexs``), False otherwise.
    """
    spurious_cexs = [cex for cex in cexs
                     if teacher.is_spurious_example(teacher.state_dim, teacher.perc_dim, candidate, cex)]
    if spurious_cexs:
        print("Spurious CEXs:", *spurious_cexs, sep='\n')
        return False
    return True
def search_part(partition, state):
    """Find the partition cell that contains ``state``.

    ``partition`` holds one sorted sequence of cut points per dimension and
    ``state`` holds one value per dimension. In each dimension the cell is the
    pair of neighbouring cut points ``(lo, hi)`` with ``lo < v <= hi``, as
    given by ``np.searchsorted`` with its default ``side="left"``.

    Returns:
        A tuple of ``(lo, hi)`` pairs, one per dimension. It is hashable, so it
        can serve as a dictionary key. Returns ``None`` if ``state`` lies
        outside the cut points in any dimension.

    Raises:
        AssertionError: if ``partition`` and ``state`` differ in length.
    """
    assert len(partition) == len(state)
    bounds = []
    for sorted_list, v in zip(partition, state):
        i = np.searchsorted(sorted_list, v)
        # i == 0: v <= first cut point; i == len: v > last cut point (or NaN).
        if i == 0 or i == len(sorted_list):
            return None
        bounds.append((sorted_list[i - 1], sorted_list[i]))
    # Previously the function fell off the end and always returned None.
    return tuple(bounds)
def synth_dtree_per_part(
        part_to_examples: Dict[Hashable, DataSet],
        teacher_builder: Callable[[], TeacherBase],
        learner_builder: Callable[[TeacherBase], LearnerBase],
        num_max_iter: int,
        ult_bound: float = 0.0,
        feature_domain: Literal["concat", "diff"] = "diff") -> List[Dict]:
    """Run ``synth_dtree`` separately on every partition and collect the outcomes.

    For each ``(part, dataset)`` pair with at least one safe data point, the
    function does the following:

    - builds a fresh teacher and learner, so no solution leaks between parts;
    - sets the teacher's old-state bound from ``part``;
    - seeds the learner with the safe (positive) and unsafe (negative) points;
    - runs ``synth_dtree`` with ``num_max_iter`` iterations.

    Args:
        part_to_examples: Maps each partition key to its data points. A key is
            turned into bounds via ``np.asarray(part, dtype=float).T``, so it
            is presumably a sequence of ``(lb, ub)`` pairs — confirm with caller.
        teacher_builder: Zero-argument factory for a new teacher.
        learner_builder: Builds a new learner bound to the given teacher.
        num_max_iter: Iteration budget passed to ``synth_dtree``.
        ult_bound: Recorded in the "found"/"not found" result entries.
        feature_domain: Recorded in the "found"/"not found" result entries.

    Returns:
        One dict per visited partition, with keys ``"part"``, ``"status"`` and
        ``"result"``. ``"status"`` is one of ``"skip"``, ``"found"``,
        ``"not found"`` or ``"exception"``; further keys depend on it. On
        Ctrl+C, the entries collected so far are returned.

    Side effects:
        After each attempted partition, ``out/pre.data`` (presumably written
        during synthesis — confirm) is moved to ``out/part-{i:03}-pre.data``
        if it exists.
    """
    result: List[Dict] = []
    for i, (part, dataset) in enumerate(part_to_examples.items()):
        safe_dps, unsafe_dps, num_nan = dataset.safe_dps, dataset.unsafe_dps, dataset.num_nan_dps
        # if a partition has no safe datapoints, then skip the partition
        # (len() rather than truthiness so array-like inputs also work)
        if len(safe_dps) == 0:
            result.append({"part": part,
                           "status": "skip",
                           "result": "No safe datapoints."})
            continue
        print(f"# safe: {len(safe_dps)}; "
              f"# unsafe: {len(unsafe_dps)}; "
              f"# NaN: {num_nan}")
        try:
            # np.asfarray was removed in NumPy 2.0; this is the equivalent call.
            lb, ub = np.asarray(part, dtype=float).T
            # NOTE We create new teacher and learner for each part to avoid solutions from other parts
            teacher = teacher_builder()
            teacher.set_old_state_bound(lb=lb, ub=ub)
            learner = learner_builder(teacher)
            learner.add_positive_examples(*safe_dps)
            learner.add_negative_examples(*unsafe_dps)
            found, ret, time_info = synth_dtree(learner, teacher, num_max_iter)
            print(f"Found? {found}")
            if found:
                k, z3_expr = ret
                # NOTE(review): the opening of this entry is missing from the
                # reviewed source; "status": "found" mirrors the other entries.
                result.append({"part": part,
                               "status": "found",
                               "feature_domain": feature_domain,
                               "ultimate_bound": ult_bound,
                               "result": {"k": k, "formula": z3_expr.sexpr(), "smtlib": z3_expr.serialize()},
                               "teacher time": time_info[0],
                               "learner time": time_info[1]})
            else:
                result.append({"part": part,
                               "feature_domain": feature_domain,
                               "ultimate_bound": ult_bound,
                               "status": "not found",
                               "result": ret})
        except KeyboardInterrupt:
            print("User pressed Ctrl+C. Skip all remaining partition.")
            break  # NOTE finally block is executed before break
        except Exception as e:
            result.append({"part": part,
                           "status": "exception",
                           "result": traceback.format_exc()})
            print(e)
        finally:
            data_file = pathlib.Path("out/pre.data")
            # If synthesis failed before the file was written, an unguarded
            # rename would raise here, masking the recorded exception and
            # aborting all remaining partitions. replace() also overwrites an
            # existing target on every platform.
            if data_file.exists():
                data_file.replace(f"out/part-{i:03}-pre.data")
    return result