Skip to content
Snippets Groups Projects
Commit bb83193a authored by chsieh16's avatar chsieh16
Browse files

Rename and save pre.data for each partition

parent 66e57949
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@
import itertools
import json
import matplotlib.pyplot as plt
import pathlib
import pickle
import traceback
from typing import Dict, Hashable, List, Tuple
......@@ -115,6 +116,7 @@ def validate_cexs(state_dim: int, perc_dim: int,
print("Spurious CEXs:", *spurious_cexs, sep='\n')
return False
def search_part(partition, state):
assert len(partition) == len(state)
bounds = []
......@@ -122,7 +124,7 @@ def search_part(partition, state):
i = np.searchsorted(sorted_list, v)
if i == 0 or i == len(sorted_list):
return None
bounds.append( (sorted_list[i-1], sorted_list[i]) )
bounds.append((sorted_list[i-1], sorted_list[i]))
return tuple(bounds)
......@@ -180,7 +182,7 @@ def main():
)
result = []
for part, (pos_exs, neg_exs, num_nan) in part_to_examples.items():
for i, (part, (pos_exs, neg_exs, num_nan)) in enumerate(part_to_examples.items()):
print("#"*80)
print(f"# positive: {len(pos_exs)}; "
f"# negative: {len(neg_exs)}; "
......@@ -196,16 +198,6 @@ def main():
perc_dim=teacher.perc_dim, timeout=20000)
learner.set_grammar([(Teacher.PERC_GT, np.zeros(2))])
if pos_exs:
pos_fv_arr = np.asfarray([learner._s2f_func(exs) for exs in pos_exs])
plt.scatter(pos_fv_arr[:, 0], pos_fv_arr[:, 1], c="g", marker="o")
if neg_exs:
neg_fv_arr = np.asfarray([learner._s2f_func(exs) for exs in neg_exs])
plt.scatter(neg_fv_arr[:, 0], neg_fv_arr[:, 1], c="r", marker="x")
plt.show()
continue # XXX Temporary skip learning and only plot feature vectors
learner.add_positive_examples(*pos_exs)
learner.add_negative_examples(*neg_exs)
try:
......@@ -227,12 +219,15 @@ def main():
"result": traceback.format_exc()})
print(e)
finally:
data_file = pathlib.Path("out/pre.data")
data_file.rename(f"out/part-{i:03}-pre.data")
del teacher
del learner
with open(f"out/dtree_synth.{NUM_Y_PARTS}x{NUM_YAW_PARTS}.out.json", "w") as f:
json.dump(result, f)
if __name__ == "__main__":
# test_synth_dtree()
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment