Skip to content
Snippets Groups Projects
Commit d5b4e8a2 authored by chsieh16's avatar chsieh16
Browse files

Filter data by stability criteria and state partition

parent bc5bdc86
No related branches found
No related tags found
No related merge requests found
File added
import pickle
import numpy as np
# Constants for Stanley controller for GEM
WHEEL_BASE = 1.75 # m
K_P = 0.45
CYCLE_TIME = 0.05 # second
FORWARD_VEL = 2.8 # m/s
STEERING_LIM = 0.61 # rad
def g(cte, phi):
error = phi + np.arctan(K_P*cte/FORWARD_VEL)
steer = np.clip(error, -STEERING_LIM, STEERING_LIM)
return (steer,)
def f(x, y, theta, steer):
new_x = x + FORWARD_VEL*np.cos(theta+steer)*CYCLE_TIME
new_y = y + FORWARD_VEL*np.sin(theta+steer)*CYCLE_TIME
new_theta = theta - FORWARD_VEL*np.sin(steer)/WHEEL_BASE*CYCLE_TIME
return new_x, new_y, new_theta
def v(x, y, theta) -> float:
return y**2 + theta**2
def pred(sample) -> bool:
x, y, theta, d, phi = sample
return v(*f(x, y, theta, *g(d, phi))) <= v(x, y, theta)
in_file_name = "collect_images_2021-11-22-17-59-46.cs598.pickle"
out_file_name = "collect_images_2021-11-22-17-59-46.cs598.filtered.pickle"
with open(in_file_name, "rb") as in_file:
pkl_data = pickle.load(in_file)
truth_samples_seq = pkl_data["truth_samples"]
filtered_truth_samples = []
for truth, samples in truth_samples_seq:
cte, phi = truth
filtered_samples = [s for s in samples if pred(s)]
if abs(len(samples) - len(filtered_samples)) < 20:
print("#Original:", len(samples), "#Filtered:", len(filtered_samples))
filtered_truth_samples.append((truth, samples))
else:
continue
with open(out_file_name, "wb") as out_file:
pkl_data["truth_samples"] = filtered_truth_samples
pickle.dump(pkl_data, out_file)
......@@ -24,11 +24,14 @@ from gem_stanley_teacher import GEMStanleyTeacher as Teacher
def test_synth_region():
pickle_file_io = open("data/collect_images_2021-11-22-17-59-46.cs598.pickle", "rb")
pickle_file_io = open("data/collect_images_2021-11-22-17-59-46.cs598.filtered.pickle", "rb")
pkl_data = pickle.load(pickle_file_io)
truth_samples_seq = pkl_data["truth_samples"]
i_th = 0 # select only the i-th partition
truth_samples_seq = truth_samples_seq[i_th:i_th+1]
truth_samples_seq = [(t, [s for s in raw_samples if not any(np.isnan(s))])
for t, raw_samples in truth_samples_seq]
# Chiao: Read in positive samples
......@@ -37,6 +40,7 @@ def test_synth_region():
]
ex_dim = len(positive_examples[0])
print("#examples: %d" % len(positive_examples))
print("Dimension of each example: %d" % ex_dim)
assert all(len(ex) == ex_dim and not any(np.isnan(ex))
for ex in positive_examples)
......@@ -46,8 +50,8 @@ def test_synth_region():
# x >= 6.0 and y >= 5
# teacher.set_old_state_bound(lb=[6.0, 5.0], ub=[11.0, 10.0])
#[::20]
synth_region(positive_examples[20:40:], teacher, num_max_iterations=20)
synth_region(positive_examples, teacher, num_max_iterations=20)
#Gurobi encoding
# 0 is diamond
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment