Skip to content
Snippets Groups Projects
Commit b8798a30 authored by aastorg2's avatar aastorg2
Browse files

finished teacher learner interface for agbot case study

parent 219b84fd
No related branches found
No related tags found
No related merge requests found
......@@ -35,50 +35,162 @@ def load_examples_from_npz(file_name: str, teacher:AgBotTeacher, partition):
bound_list = list(list(zip(x_arr[:-1], x_arr[1:])) for x_arr in partition)
ret = {part: [[], [], 0] for part in itertools.product(*bound_list)}
print(teacher.is_positive_example(data_5d[0]))
exit(0)
num_excl_samples =0
for dpoint in data_5d:
x,y,theta = dpoint[0], dpoint[1], dpoint[2]
#print(x,y,tetha)
#print(dpoint)
vehicle_state = dpoint[0:teacher.state_dim]
part = search_part(partition, vehicle_state)
if part is None:
num_excl_samples += 1
continue
if np.any(np.isnan(dpoint)):
ret[part][2] += 1
elif teacher.is_safe_state(dpoint):
ret[part][0].append(dpoint)
else:
ret[part][1].append(dpoint)
#print(dpoint)
print("# samples not in any selected parts:", num_excl_samples)
return ret
def synth_dtree(learner: Learner, teacher: AgBotTeacher, num_max_iterations: int = 10):
past_candidate_list = []
for k in range(num_max_iterations):
print("="*80)
print(f"Iteration {k}:", sep='')
print("learning ....")
candidate = learner.learn()
print("done learning")
print(f"candidate: {candidate}")
past_candidate_list.append(candidate)
# QUERYING TEACHER IF THERE ARE NEGATIVE EXAMPLES
result = teacher.check(candidate)
print(f"Satisfiability: {result}")
if result == z3.sat:
negative_examples = teacher.model()
assert len(negative_examples) > 0
print(f"negative examples: {negative_examples}")
# assert validate_cexs(teacher.state_dim, teacher.perc_dim, candidate, negative_examples)
learner.add_negative_examples(*negative_examples)
continue
elif result == z3.unsat:
print("we are done!")
return True, (k, z3.simplify(candidate, arith_lhs=True).sexpr())
else:
return False, f"Reason Unknown {teacher.reason_unknown()}"
return False, f"Reached max iteration {num_max_iterations}."
with open(file_name, "rb") as pickle_file_io:
pkl_data = pickle.load(pickle_file_io)
truth_samples_seq = pkl_data["truth_samples"]
# Convert from sampled states and percepts to positive and negative examples for learning
pos_exs, neg_exs, num_excl_exs = [], [], 0
for _, ss in truth_samples_seq:
for s in ss:
if np.any(np.isnan(s)):
num_excl_exs += 1
elif spec(s):
pos_exs.append(s)
else:
neg_exs.append(s)
print("# Excluded NaN examples:", num_excl_exs)
return pos_exs, neg_exs
if __name__ == "__main__":
NPZ_FILE_PATH="data/perceptor-agbot-collect_images_2021-10-29-01-37-44-0.0-50.0.npz"
# Partitions on prestate
X_LIM = np.inf
X_ARR = np.array([-X_LIM, X_LIM])
Y_LIM = 1.2
NUM_Y_PARTS = 4
Y_ARR = np.linspace(-Y_LIM, Y_LIM, NUM_Y_PARTS + 1)
YAW_LIM = np.pi / 12
PRE_Y_LIM = 0.228
NUM_Y_PARTS = 10
Y_ARR = np.linspace(-PRE_Y_LIM, PRE_Y_LIM, NUM_Y_PARTS + 1)
#temp = np.round(Y_ARR, decimals=4)
PRE_YAW_LIM = np.pi / 6
NUM_YAW_PARTS = 10
YAW_ARR = np.linspace(-YAW_LIM, YAW_LIM, NUM_YAW_PARTS + 1)
YAW_ARR = np.linspace(-PRE_YAW_LIM, PRE_YAW_LIM, NUM_YAW_PARTS + 1)
PARTITION = (X_ARR, Y_ARR, YAW_ARR)
NUM_MAX_ITER = 500
FEATURE_DOMAIN = "concat"
ULT_BOUND = 0.0
NORM_ORD = 1
teacher= AgBotTeacher(norm_ord=NORM_ORD, ultimate_bound=ULT_BOUND)
part_to_examples = load_examples_from_npz(NPZ_FILE_PATH, teacher, PARTITION)
# Print statistics about training data points
print("#"*80)
print("Parts with unsafe data points:")
for i, (part, (safe_dps, unsafe_dps, num_nan)) in enumerate(part_to_examples.items()):
lb, ub = np.asfarray(part).T
lb[2] = np.rad2deg(lb[2])
ub[2] = np.rad2deg(ub[2])
if len(unsafe_dps) > 0:
print(f"Part Index {i}:", f"y in [{lb[1]:.03}, {ub[1]:.03}] (m);", f"θ in [{lb[2]:.03}, {ub[2]:.03}] (deg);",
f"# safe: {len(safe_dps)}", f"# unsafe: {len(unsafe_dps):03}", f"# NaN: {num_nan}")
main_teacher= AgBotTeacher()
load_examples_from_npz(NPZ_FILE_PATH, main_teacher, PARTITION)
result = []
for i, (part, (safe_dps, unsafe_dps, num_nan)) in enumerate(part_to_examples.items()):
#if not i == 1:
# continue
# if a partition has zero positive datapoints, then skip the partion
if len(safe_dps) == 0:
continue
if len(safe_dps) == 17:
print()
print("#"*80)
print(f"# safe: {len(safe_dps)}; "
f"# unsafe: {len(unsafe_dps)}; "
f"# NaN: {num_nan}")
lb, ub = np.asfarray(part).T
teacher = AgBotTeacher(norm_ord=NORM_ORD, ultimate_bound=ULT_BOUND)
teacher.set_old_state_bound(lb=lb, ub=ub)
learner = Learner(state_dim=teacher.state_dim,
perc_dim=teacher.perc_dim, timeout=20000)
learner.set_grammar([(AgBotTeacher.PERC_GT, np.zeros(2))], FEATURE_DOMAIN)
learner.add_positive_examples(*safe_dps)
learner.add_negative_examples(*unsafe_dps)
try:
found, ret = synth_dtree(learner, teacher,
num_max_iterations=NUM_MAX_ITER)
print(f"Found? {found}")
if found:
k, expr = ret
result.append({"part": part,
"feature_domain": FEATURE_DOMAIN,
"ultimate_bound": ULT_BOUND,
"status": "found",
"result": {"k": k, "formula": expr}})
else:
result.append({"part": part,
"feature_domain": FEATURE_DOMAIN,
"ultimate_bound": ULT_BOUND,
"status": "not found",
"result": ret})
except KeyboardInterrupt:
print("User pressed Ctrl+C. Skip all remaining partition.")
break # NOTE finally block is executed before break
except Exception as e:
result.append({"part": part,
"status": "exception",
"result": traceback.format_exc()})
print(e)
finally:
data_file = pathlib.Path("out/pre.data")
data_file.rename(f"out/part-{i:03}-pre.data")
del teacher
del learner
with open(f"out/dtree_synth.{NUM_Y_PARTS}x{NUM_YAW_PARTS}.out.json", "w") as f:
json.dump(result, f)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment