diff --git a/dtree_synth_agbot.py b/dtree_synth_agbot.py index fe1a7efc0893dc202bfb352e858c230329e775fc..b2dad32ba91a0fccf1a110991b3af9bb54575148 100644 --- a/dtree_synth_agbot.py +++ b/dtree_synth_agbot.py @@ -13,6 +13,15 @@ import z3 from dtree_learner import DTreeLearner as Learner from agbot_stanley_teacher import DTreeAgBotStanleyGurobiTeacher as AgBotTeacher +def search_part(partition, state): + assert len(partition) == len(state) + bounds = [] + for sorted_list, v in zip(partition, state): + i = np.searchsorted(sorted_list, v) + if i == 0 or i == len(sorted_list): + return None + bounds.append((sorted_list[i-1], sorted_list[i])) + return tuple(bounds) def load_examples_from_npz(file_name: str, teacher:AgBotTeacher, partition): print("Loading examples from .npz")