Skip to content
Snippets Groups Projects
Commit 7bc53946 authored by chsieh16's avatar chsieh16
Browse files

Add npy file for AgBot case

parent 85362be9
No related branches found
No related tags found
No related merge requests found
File added
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import itertools import itertools
import json import json
from typing import Dict, Hashable, Literal from typing import Dict, Hashable, Iterable, Literal, Tuple
import numpy as np import numpy as np
from dtree_synth import DataSet, search_part, synth_dtree_per_part from dtree_synth import DataSet, search_part, synth_dtree_per_part
...@@ -11,7 +11,7 @@ from agbot_stanley_teacher import DTreeAgBotStanleyGurobiTeacher ...@@ -11,7 +11,7 @@ from agbot_stanley_teacher import DTreeAgBotStanleyGurobiTeacher
def load_examples_from_numpy_array( def load_examples_from_numpy_array(
data_5d_numpy: np.ndarray, data_5d: Iterable[Tuple[float, ...]],
teacher: DTreeAgBotStanleyGurobiTeacher, teacher: DTreeAgBotStanleyGurobiTeacher,
partition) -> Dict[Hashable, DataSet]: partition) -> Dict[Hashable, DataSet]:
# x_arr[:-1], except the last one, x_arr[1:] except the first one # x_arr[:-1], except the last one, x_arr[1:] except the first one
...@@ -19,7 +19,7 @@ def load_examples_from_numpy_array( ...@@ -19,7 +19,7 @@ def load_examples_from_numpy_array(
ret = {part: DataSet() for part in itertools.product(*bound_list)} ret = {part: DataSet() for part in itertools.product(*bound_list)}
num_excl_samples = 0 num_excl_samples = 0
for dpoint in data_5d_numpy: for dpoint in data_5d:
vehicle_state = dpoint[0:teacher.state_dim] vehicle_state = dpoint[0:teacher.state_dim]
part = search_part(partition, vehicle_state) part = search_part(partition, vehicle_state)
if part is None: if part is None:
...@@ -28,19 +28,19 @@ def load_examples_from_numpy_array( ...@@ -28,19 +28,19 @@ def load_examples_from_numpy_array(
if np.any(np.isnan(dpoint)): if np.any(np.isnan(dpoint)):
ret[part].num_nan_dps += 1 ret[part].num_nan_dps += 1
elif teacher.is_safe_state(dpoint): elif teacher.is_safe_state(dpoint):
ret[part].safe_dps.append(tuple(dpoint)) ret[part].safe_dps.append(dpoint)
else: else:
ret[part].unsafe_dps.append(tuple(dpoint)) ret[part].unsafe_dps.append(dpoint)
print("# samples not in any selected parts:", num_excl_samples) print("# samples not in any selected parts:", num_excl_samples)
return ret return ret
def main(dom: Literal["concat", "diff"], ult_bound: float): def main(dom: Literal["concat", "diff"], ult_bound: float):
NPZ_FILE_PATH = "data/perceptor-agbot-collect_images_2021-10-29-01-37-44-0.0-50.0.npz" NPY_FILE_PATH = "data/400_truths-uniform_partition_20x20-0.228m-pi_6-agbot-2021-10-29-01-37-44.npy"
print("Loading examples from .npz") print("Loading examples from .npy")
with np.load(NPZ_FILE_PATH) as npz_data: data_5d_structured_arr = np.load(NPY_FILE_PATH)
data_5d_numpy = npz_data['dps_arr'] data_5d = data_5d_structured_arr[['x', 'y', 'yaw', 'cte', 'psi']].tolist()
# Partitions on prestate # Partitions on prestate
...@@ -48,11 +48,11 @@ def main(dom: Literal["concat", "diff"], ult_bound: float): ...@@ -48,11 +48,11 @@ def main(dom: Literal["concat", "diff"], ult_bound: float):
X_ARR = np.array([-X_LIM, X_LIM]) X_ARR = np.array([-X_LIM, X_LIM])
PRE_Y_LIM = 0.228 PRE_Y_LIM = 0.228
NUM_Y_PARTS = 10 NUM_Y_PARTS = 5
Y_ARR = np.linspace(-PRE_Y_LIM, PRE_Y_LIM, NUM_Y_PARTS + 1) Y_ARR = np.linspace(-PRE_Y_LIM, PRE_Y_LIM, NUM_Y_PARTS + 1)
PRE_YAW_LIM = np.pi / 6 PRE_YAW_LIM = np.pi / 6
NUM_YAW_PARTS = 10 NUM_YAW_PARTS = 5
YAW_ARR = np.linspace(-PRE_YAW_LIM, PRE_YAW_LIM, NUM_YAW_PARTS + 1) YAW_ARR = np.linspace(-PRE_YAW_LIM, PRE_YAW_LIM, NUM_YAW_PARTS + 1)
PARTITION = (X_ARR, Y_ARR, YAW_ARR) PARTITION = (X_ARR, Y_ARR, YAW_ARR)
...@@ -62,7 +62,7 @@ def main(dom: Literal["concat", "diff"], ult_bound: float): ...@@ -62,7 +62,7 @@ def main(dom: Literal["concat", "diff"], ult_bound: float):
NORM_ORD = 1 NORM_ORD = 1
teacher = DTreeAgBotStanleyGurobiTeacher(norm_ord=NORM_ORD, ultimate_bound=ULT_BOUND) teacher = DTreeAgBotStanleyGurobiTeacher(norm_ord=NORM_ORD, ultimate_bound=ULT_BOUND)
part_to_examples = load_examples_from_numpy_array(data_5d_numpy, teacher, PARTITION) part_to_examples = load_examples_from_numpy_array(data_5d, teacher, PARTITION)
# Print statistics about training data points # Print statistics about training data points
print("#"*80) print("#"*80)
...@@ -75,7 +75,7 @@ def main(dom: Literal["concat", "diff"], ult_bound: float): ...@@ -75,7 +75,7 @@ def main(dom: Literal["concat", "diff"], ult_bound: float):
if len(unsafe_dps) > 0: if len(unsafe_dps) > 0:
print(f"Part Index {i}:", f"y in [{lb[1]:.03}, {ub[1]:.03}] (m);", f"θ in [{lb[2]:.03}, {ub[2]:.03}] (deg);", print(f"Part Index {i}:", f"y in [{lb[1]:.03}, {ub[1]:.03}] (m);", f"θ in [{lb[2]:.03}, {ub[2]:.03}] (deg);",
f"# safe: {len(safe_dps)}", f"# unsafe: {len(unsafe_dps):03}", f"# NaN: {num_nan}") f"# safe: {len(safe_dps)}", f"# unsafe: {len(unsafe_dps)}", f"# NaN: {num_nan}")
return return
...@@ -102,4 +102,4 @@ def main(dom: Literal["concat", "diff"], ult_bound: float): ...@@ -102,4 +102,4 @@ def main(dom: Literal["concat", "diff"], ult_bound: float):
if __name__ == "__main__": if __name__ == "__main__":
main("concat", 0.0) main("concat", 0.5)
\ No newline at end of file \ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment