Skip to content
Snippets Groups Projects
Commit ac7b6e86 authored by Words Hua's avatar Words Hua
Browse files

Add Trajectory Analyzer

parent c2275528
No related branches found
No related tags found
No related merge requests found
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
class SignalAnalyzer:
def __init__(self, folder_name, file_base_name):
"""
Initialize the SignalAnalyzer class with file paths and load the signal data.
Args:
folder_name: The folder where the signal files are located.
file_base_name: The base name of the signal file (without extension).
"""
self.folder_name = folder_name
self.file_base_name = file_base_name
self.file_path = os.path.join(r'C:\Users\Cs Egg-Eater\University of Illinois - Urbana\Covey Lab - Documents\Cs\ACSworkfile',
folder_name, file_base_name + '.csv')
self.data_signals = pd.read_csv(self.file_path, skiprows=4)
self.channel_abbr = self.data_signals.iloc[3, 1:].values
def ensure_commas(self, line, required_commas=5):
"""
Ensure that a line has at least the required number of commas.
If not, append additional commas to the line.
"""
comma_count = line.count(',')
if comma_count < required_commas:
line = line.strip() + ',' * (required_commas - comma_count)
return line + '\n'
return line
def reformat_file_with_commas(self):
"""
Reformat the loaded signal file to ensure proper comma structure.
"""
with open(self.file_path, 'r') as file:
lines = file.readlines()
required_commas = lines[6].count(',') # Count the commas in the 7th line
# Apply the rule to the first 5 lines (index 0 to 4)
for i in range(5):
lines[i] = self.ensure_commas(lines[i], required_commas)
# Save the reformatted file
with open(self.file_path, 'w') as output_file:
output_file.writelines(lines)
def load_and_plot_signal_general(self, cutoff=1, output_channel=[]):
"""
Generalized function to plot any number of signal channels with different colors and separate Y-axis scales.
Args:
cutoff: A percentage cutoff to limit the plotted data.
output_channel: A list of output channels to plot. If empty, all channels are plotted.
"""
file_name = os.path.splitext(os.path.basename(self.file_path))[0]
# Extract the actual signal data from row 4 onwards
signal_data_corrected = self.data_signals.iloc[4:, 1:].astype(float).values
# Time axis: sampling time is dynamically read from row 8 (row 3 in zero-indexing)
sampling_interval_ms = float(self.data_signals.iloc[2, 1]) # Read the sampling time from row 8 (column B)
num_samples = signal_data_corrected.shape[0] # Number of samples
time_array = [sampling_interval_ms * i for i in range(1, num_samples + 1)] # Time in ms
# Cut off the time array and signal data based on the cutoff percentage
time_array_cutoff = time_array[:int(cutoff * len(time_array))]
signal_data_corrected_cutoff = signal_data_corrected[:int(cutoff * len(time_array)), :]
time_array = time_array_cutoff
signal_data_corrected = signal_data_corrected_cutoff
if output_channel:
output_channel = [channel - 1 for channel in output_channel if 0 < channel <= signal_data_corrected.shape[1]]
signal_data_corrected = signal_data_corrected[:, output_channel]
# Define a list of distinct colors for the signals
colors = ['purple', 'blue', 'green', 'red', 'cyan', 'magenta', 'orange', 'pink', 'brown', 'yellow']
color_cycle = colors * ((signal_data_corrected.shape[1] // len(colors)) + 1)
fig, ax1 = plt.subplots(figsize=(16, 8))
lines = []
if output_channel:
line, = ax1.plot(time_array, signal_data_corrected[:, 0], '-', color=color_cycle[output_channel[0]], label=f'CH{1 + output_channel[0]}: {self.channel_abbr[output_channel[0]]}')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel(f'CH{1 + output_channel[0]}', color=color_cycle[output_channel[0]])
ax1.tick_params(axis='y', labelcolor=color_cycle[output_channel[0]])
else:
line, = ax1.plot(time_array, signal_data_corrected[:, 0], '-', color=color_cycle[0], label=f'CH1: {self.channel_abbr[0]}')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('CH1', color=color_cycle[0])
ax1.tick_params(axis='y', labelcolor=color_cycle[0])
lines.append(line)
# Plot remaining channels
previous_axis = ax1
num_channels = signal_data_corrected.shape[1]
for i in range(1, num_channels):
new_axis = previous_axis.twinx()
new_axis.spines['right'].set_position(('outward', 60 * (i-1)))
if output_channel:
line, = new_axis.plot(time_array, signal_data_corrected[:, i], '-', color=color_cycle[output_channel[i]], label=f'CH{output_channel[i] + 1}: {self.channel_abbr[output_channel[i]]}')
new_axis.set_ylabel(f'CH {output_channel[i] + 1}', color=color_cycle[output_channel[i]])
new_axis.tick_params(axis='y', labelcolor=color_cycle[output_channel[i]])
else:
line, = new_axis.plot(time_array, signal_data_corrected[:, i], '-', color=color_cycle[i], label=f'CH{i+1}: {self.channel_abbr[i]}')
new_axis.set_ylabel(f'CH {i+1}', color=color_cycle[i])
new_axis.tick_params(axis='y', labelcolor=color_cycle[i])
new_axis.spines['left'].set_visible(False)
new_axis.yaxis.set_label_position('right')
new_axis.spines['right'].set_visible(True)
new_axis.yaxis.tick_right()
lines.append(line)
ax1.legend(handles=lines, loc='upper left')
plt.title(f'{file_name}')
fig.tight_layout()
plt.show()
def signal_statistics_analysis(self):
"""
Perform basic statistical analysis on each signal channel.
"""
signal_data_corrected = self.data_signals.iloc[4:, 1:].astype(float).values
sampling_interval_ms = float(self.data_signals.iloc[2, 1])
num_samples = signal_data_corrected.shape[0]
time_array = [sampling_interval_ms * i for i in range(1, num_samples + 1)]
for i in range(signal_data_corrected.shape[1]):
channel_data = signal_data_corrected[:, i]
max_value = channel_data.max()
max_position = time_array[channel_data.argmax()]
min_value = channel_data.min()
min_position = time_array[channel_data.argmin()]
magnitude_of_change = max_value - min_value
average_value = channel_data.mean()
std_deviation = channel_data.std()
print(f'Statistics for Channel {i+1} ({self.channel_abbr[i]}):')
print(f' Maximum Value: {max_value} @ {max_position:.2f} ms')
print(f' Minimum Value: {min_value} @ {min_position:.2f} ms')
print(f' Magnitude of Change: {magnitude_of_change}')
print(f' Average: {average_value}')
print(f' Standard Deviation: {std_deviation}')
print()
def analyze_trajectory_in_range(self, target_position_L, target_position_R):
"""
Analyze trajectory segments where the signal is within the target position range.
Args:
target_position_L: Left boundary of the target position range.
target_position_R: Right boundary of the target position range.
"""
trajectory_channel_idx = self.find_channel_index('FPOS')
PE_channel_idx = self.find_channel_index('PE')
VEL_channel_idx = self.find_channel_index('FVEL')
RMSM_channel_idx = self.find_channel_index('RMSM')
trajectory_data = self.data_signals.iloc[4:, trajectory_channel_idx+1].astype(float).values
PE_data = self.data_signals.iloc[4:, PE_channel_idx+1].astype(float).values
VEL_data = self.data_signals.iloc[4:, VEL_channel_idx+1].astype(float).values
RMSM_data = self.data_signals.iloc[4:, RMSM_channel_idx+1].astype(float).values
periods_df = self.find_position_period(trajectory_data, target_position_L, target_position_R)
idx_list = []
magnitude_PE_list = []
max_RMSM_list = []
for idx, row in periods_df.iterrows():
segment_PE = PE_data[row['Start_Index']:row['End_Index']+1]
segment_RMSM = RMSM_data[row['Start_Index']:row['End_Index']+1]
magnitude_PE = segment_PE.max() - segment_PE.min()
max_RMSM = segment_RMSM.max()
# Collect data for plotting
idx_list.append(idx)
magnitude_PE_list.append(magnitude_PE)
max_RMSM_list.append(max_RMSM)
# Plot idx vs. magnitude of PE
self._plot(idx_list, magnitude_PE_list, 'Segment Index', 'Magnitude of PE', 'Segment Index vs. Magnitude of PE')
# Plot idx vs. maximum value of RMSM
self._plot(idx_list, max_RMSM_list, 'Segment Index', 'Max RMSM', 'Segment Index vs. Max RMSM')
# Plot magnitude of PE vs. max RMSM
self._plot(magnitude_PE_list, max_RMSM_list, 'Magnitude of PE', 'Max RMSM', 'Magnitude of PE vs. Max RMSM')
# Draw phase plot of trajectory
self.draw_phase_plot(trajectory_data, VEL_data, target_position_L, target_position_R)
def find_channel_index(self, abbreviation):
"""
Find the channel index by abbreviation (e.g., 'FPOS', 'PE', 'FVEL', 'RMSM').
Args:
abbreviation: The abbreviation of the channel.
Returns:
Index of the channel or raises an error if not found.
"""
for i, abbr in enumerate(self.channel_abbr):
if abbr == abbreviation:
return i
raise ValueError(f"Channel '{abbreviation}' not found in the data.")
def find_position_period(self, trajectory_data, target_position_L, target_position_R):
"""
Find the periods when the trajectory is within the specified target range.
Args:
trajectory_data: The array of trajectory data.
target_position_L: Left boundary of the target position range.
target_position_R: Right boundary of the target position range.
Returns:
DataFrame containing the start and end indices for each period.
"""
within_range = (trajectory_data >= target_position_L) & (trajectory_data <= target_position_R)
periods = []
start_idx = None
for i, in_range in enumerate(within_range):
if in_range and start_idx is None:
start_idx = i
elif not in_range and start_idx is not None:
periods.append((start_idx, i-1))
start_idx = None
if start_idx is not None:
periods.append((start_idx, len(within_range)-1))
return pd.DataFrame(periods, columns=['Start_Index', 'End_Index'])
def draw_phase_plot(self, position_data, velocity_data, target_position_L, target_position_R):
"""
Draw the phase plot of position vs velocity with the target range.
Args:
position_data: Array of position data.
velocity_data: Array of velocity data.
target_position_L: Left boundary of the target position range.
target_position_R: Right boundary of the target position range.
"""
plt.figure(figsize=(12, 6))
plt.plot(position_data, velocity_data, 'b-', label='Trajectory')
plt.axvline(x=target_position_L, color='r', linestyle='--', label='Target Position L')
plt.axvline(x=target_position_R, color='g', linestyle='--', label='Target Position R')
plt.xlabel('Position')
plt.ylabel('Velocity')
plt.title('Phase Plot: Position vs. Velocity')
plt.legend()
plt.grid(True)
plt.show()
def _plot(self, x_data, y_data, x_label, y_label, title):
"""
Utility function to create a plot.
Args:
x_data: Data for the x-axis.
y_data: Data for the y-axis.
x_label: Label for the x-axis.
y_label: Label for the y-axis.
title: Title of the plot.
"""
plt.figure(figsize=(10, 5))
plt.plot(x_data, y_data, 'o-', label=title)
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.title(title)
plt.grid(True)
plt.legend()
plt.show()
# main.py
from simulation import Simulation
from visualization import plot_trajectories, plot_histogram
from optimization import optimize_parameters
from geometry import D_CELL, L_WIDE, L_GLASS, L_CELL
import matplotlib.pyplot as plt
def run_simulation():
# Define initial parameters
x_tube = (28.194 + 45.6184) / 2 * 1e-3 # m
y_tube = 13.8 * 1e-3 # m
temperature_x = 20197e-6 # K
temperature_y = 5e-6 # K
v0x_center = 4.40 # m/s
v0y_center = 0.57 # m/s
N = 10000
# Create simulation
sim = Simulation(tube_start=(x_tube, y_tube), temp=(temperature_x, temperature_y))
# Sample velocities
v0xs, v0ys = sim.sample_velocities(v0x_center, v0y_center, N)
# Filter velocities that will pass through the tube
v0xs, v0ys = sim.filter_velocities(v0xs, v0ys)
# Count remaining valid velocities
print(f"Number of remaining velocities: {v0xs.shape[0]}/{N}")
# Measure x crossings of y = y_tube
other_x_crossings, x_crossings = sim.measure_xs(sim.y_tube, v0xs, v0ys)
# Plot results
plt.figure(figsize=(10, 5))
# Define position limits
pos_scale = 1e3 # mm/m
xlim = (-0.005, sim.x_tube + L_WIDE + L_GLASS + L_CELL + 0.01)
ylim = (-0.005, sim.y_tube + D_CELL / 2 + 0.01)
# Plot histogram of x-crossings
plot_histogram(x_crossings, N, pos_scale)
# optimization.py
from simulation import Simulation
import numpy as np
from scipy.optimize import minimize
def objective(params, N, temperature_x, temperature_y):
v0x_center, v0y_center, x_tube_rescale, y_tube_rescale = params
x_tube = x_tube_rescale / 100
y_tube = y_tube_rescale / 100
sim = Simulation(tube_start=(x_tube, y_tube), temp=(temperature_x, temperature_y))
v0xs, v0ys = sim.sample_velocities(v0x_center, v0y_center, N)
v0xs, v0ys = sim.filter_velocities(v0xs, v0ys)
other_x_crossings, x_crossings = sim.measure_xs(y_tube, v0xs, v0ys)
relative_x_crossings = x_crossings - (sim.x_tube + L_WIDE + L_GLASS)
bin_counts, bin_edges = np.histogram(relative_x_crossings, bins=50)
total_probability = 100 * np.sum(bin_counts) / N
return -total_probability
def optimize_parameters(N, initial_guess, bounds):
result = minimize(objective, initial_guess, bounds=bounds, args=(N,))
return result.x
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment