Skip to content
Snippets Groups Projects
scatter.rs 8.18 KiB
Newer Older
whooie's avatar
whooie committed
use thiserror::Error;
use crate::{
whooie's avatar
whooie committed
        Atom,
        AtomIter,
whooie's avatar
whooie committed
        RadiationPattern,
        State,
whooie's avatar
whooie committed
    },
    newton::{
        NewtonError,
whooie's avatar
whooie committed
        PhaseSpace,
        rka,
    },
    trap::Trap,
};

#[derive(Error, Debug)]
pub enum ScatterError {
    #[error("newton error {0}")]
    NewtonError(NewtonError),
    #[error("StateHistory::new: components have unequal lengths")]
    StateHistoryUnequalLength,
    #[error("Trajectory::new: components have unequal lengths")]
    TrajectoryUnequalLength,
whooie's avatar
whooie committed
}
/// Convenience alias for a `Result` whose error type is [`ScatterError`].
pub type ScatterResult<T> = Result<T, ScatterError>;

/// Simple data struct to hold a time series of states.
#[derive(Clone, Debug, Default)]
pub struct StateHistory<S>
where S: State
whooie's avatar
whooie committed
{
whooie's avatar
whooie committed
}

whooie's avatar
whooie committed
{
    /// Construct a new `StateHistory`. Fails if the two series are of unequal
    /// lengths.
    pub fn new(t: Vec<f64>, s: Vec<S>) -> ScatterResult<Self> {
        return if t.len() == s.len() {
            Ok(Self { t, s })
        } else {
            Err(ScatterError::StateHistoryUnequalLength)
        };
    }

    pub fn len(&self) -> usize { self.t.len() }

    pub fn is_empty(&self) -> bool { self.t.len() == 0 }

    /// Push a new data point onto the end of the series.
    pub fn push(&mut self, t: f64, s: S) {
        self.t.push(t);
        self.s.push(s);
    }

    /// Return an iterator over `(time, state)` pairs.
    pub fn iter(&self) -> StateHistoryIter<S> {
        return StateHistoryIter { state_history: self, idx: 0 };
    }
}

impl<S> FromIterator<(f64, S)> for StateHistory<S>
where S: State
{
    /// Collect `(time, state)` pairs into a `StateHistory`, splitting each
    /// pair into the time series and the state series.
    fn from_iter<I>(iter: I) -> Self
    where I: IntoIterator<Item = (f64, S)>
    {
        // Extending a tuple of collections distributes each pair's halves,
        // which is what `unzip` does under the hood.
        let mut series: (Vec<f64>, Vec<S>) = (Vec::new(), Vec::new());
        series.extend(iter);
        let (t, s) = series;
        return Self { t, s };
    }
}

impl<S> IntoIterator for StateHistory<S>
where S: State
{
    type Item = (f64, S);
    type IntoIter
        = std::iter::Zip<std::vec::IntoIter<f64>, std::vec::IntoIter<S>>;

    /// Consume the history, yielding owned `(time, state)` pairs in order.
    fn into_iter(self) -> Self::IntoIter {
        let Self { t: times, s: states } = self;
        return times.into_iter().zip(states);
    }
}

/// Borrowing iterator over the `(time, state)` pairs of a [`StateHistory`];
/// this is the type constructed by `StateHistory::iter`.
pub struct StateHistoryIter<'a, S>
where S: State
{
    /// History being traversed.
    state_history: &'a StateHistory<S>,
    /// Index of the next pair to yield.
    idx: usize,
}

impl<'a, S> Iterator for StateHistoryIter<'a, S>
where S: State
{
    type Item = (&'a f64, &'a S);

    fn next(&mut self) -> Option<Self::Item> {
        if self.idx < self.state_history.len() {
            let ret
                = Some((
                    &self.state_history.t[self.idx],
                    &self.state_history.s[self.idx],
                ));
            self.idx += 1;
            return ret;
        } else {
            return None;
        };
    }
}

/// Simple data struct to hold a time series of `PhaseSpace` vectors.
///
/// The two series are index-aligned: entry `k` of `q` is the phase-space point
/// at time `t[k]`.
#[derive(Clone, Debug, Default)]
pub struct Trajectory
{
    /// Sample times.
    t: Vec<f64>,
    /// Phase-space points, index-aligned with `t`.
    q: Vec<PhaseSpace>,
}

/// Constructors and accessors for [`Trajectory`], a time series of
/// `PhaseSpace` vectors.
impl Trajectory {
    /// Construct a new `Trajectory`. Fails if the two series are of unequal
    /// lengths.
    ///
    /// # Errors
    ///
    /// Returns [`ScatterError::TrajectoryUnequalLength`] when
    /// `t.len() != q.len()`.
    pub fn new(t: Vec<f64>, q: Vec<PhaseSpace>) -> ScatterResult<Self> {
        if t.len() != q.len() {
            return Err(ScatterError::TrajectoryUnequalLength);
        }
        return Ok(Self { t, q });
    }

    /// Push a new data point onto the end of the series.
    pub fn push(&mut self, t: f64, q: PhaseSpace) {
        self.t.push(t);
        self.q.push(q);
    }

    /// Move time series data into `self`, leaving `t` and `q` empty.
    ///
    /// The two vectors are expected to be of equal length; otherwise the
    /// stored series stop being index-aligned.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `t.len() != q.len()`.
    pub fn append(&mut self, t: &mut Vec<f64>, q: &mut Vec<PhaseSpace>) {
        // `new` enforces equal lengths; keep that invariant from being broken
        // silently here.
        debug_assert_eq!(
            t.len(),
            q.len(),
            "Trajectory::append: components have unequal lengths",
        );
        self.t.append(t);
        self.q.append(q);
    }

    /// Number of `(time, phasespace)` pairs held.
    pub fn len(&self) -> usize { self.t.len() }

    /// `true` when no data points are held.
    pub fn is_empty(&self) -> bool { self.t.is_empty() }

    /// Return an iterator over `(time, phasespace)` pairs, borrowed from
    /// `self`.
    pub fn iter(&self) -> TrajectoryIter<'_> {
        return TrajectoryIter { trajectory: self, idx: 0 };
    }
}

impl IntoIterator for Trajectory {
    type Item = (f64, PhaseSpace);
    type IntoIter
        = std::iter::Zip<
            std::vec::IntoIter<f64>,
            std::vec::IntoIter<PhaseSpace>,
        >;

    /// Consume the trajectory, yielding owned `(time, phasespace)` pairs in
    /// order.
    fn into_iter(self) -> Self::IntoIter {
        let Self { t: times, q: points } = self;
        return times.into_iter().zip(points);
    }
}

impl FromIterator<(f64, PhaseSpace)> for Trajectory {
    /// Collect `(time, phasespace)` pairs into a `Trajectory`, splitting each
    /// pair into the time series and the phase-space series.
    fn from_iter<I>(iter: I) -> Self
    where I: IntoIterator<Item = (f64, PhaseSpace)>
    {
        // Extending a tuple of collections distributes each pair's halves,
        // which is what `unzip` does under the hood.
        let mut series: (Vec<f64>, Vec<PhaseSpace>) = (Vec::new(), Vec::new());
        series.extend(iter);
        let (t, q) = series;
        return Self { t, q };
    }
}

/// Borrowing iterator over the `(time, phasespace)` pairs of a
/// [`Trajectory`]; this is the type constructed by `Trajectory::iter`.
pub struct TrajectoryIter<'a> {
    /// Trajectory being traversed.
    trajectory: &'a Trajectory,
    /// Index of the next pair to yield.
    idx: usize,
}

impl<'a> Iterator for TrajectoryIter<'a> {
    type Item = (&'a f64, &'a PhaseSpace);

    fn next(&mut self) -> Option<Self::Item> {
        if self.idx < self.trajectory.len() {
            let ret
                = Some((
                    &self.trajectory.t[self.idx],
                    &self.trajectory.q[self.idx],
                ));
            self.idx += 1;
            return ret;
        } else {
            return None
whooie's avatar
whooie committed
        }
    }
}

/// Simple data class holding both a state history and a phase-space trajectory.
///
/// The two members are independent time series; nothing in this type ties
/// their lengths or sample times together.
#[derive(Clone, Debug, Default)]
pub struct ScatterHistory<S>
where S: State
{
    /// Time series of states.
    pub states: StateHistory<S>,
    /// Time series of phase-space points.
    pub traj: Trajectory,
}

impl<S> ScatterHistory<S>
where S: State
{
    /// Record state `s` at time `t` at the end of the state history.
    pub fn push_state(&mut self, t: f64, s: S) {
        let history = &mut self.states;
        history.push(t, s);
    }

    /// Record phase-space point `q` at time `t` at the end of the trajectory.
    pub fn push_traj(&mut self, t: f64, q: PhaseSpace) {
        let path = &mut self.traj;
        path.push(t, q);
    }
}

/// Calculate a single trajectory up to `t_final` *or* until the atom goes dark;
/// i.e. the final time in the trajectory is not guaranteed to be equal to
/// `t_final`.
pub fn scatter_sim<S, T, R>(atom: &Atom<S, T, R>, t_final: f64)
    -> ScatterResult<(ScatterHistory<S>, Atom<S, T, R>)>
whooie's avatar
whooie committed
where
    S: State,
    T: Trap,
    R: RadiationPattern,
{
    let m: f64 = atom.mass;
    let mut t_cur: f64 = 0.0;
    let mut state_history = StateHistory::new(vec![t_cur], vec![atom.state])?;
    let mut trajectory = Trajectory::new(Vec::new(), Vec::new())?;
    let mut atom_iter: AtomIter<S, T, R> = atom.to_state_iter();
    let (mut dt, mut t_int_end): (f64, f64);
    let mut tq_int: (Vec<f64>, Vec<PhaseSpace>);
    // simulate until t_final or the atom goes dark
    while let Some((state, photon_int)) = atom_iter.next() {
        dt
            = atom_iter.get_trap().period(atom.mass)
            .min(photon_int.time())
            .min(photon_int.time_mean())
            .min(t_final - t_cur)
            / 10.0;
        t_int_end = (t_cur + photon_int.time()).min(t_final);
        let int_rhs
            = |_t: f64, q: PhaseSpace| -> PhaseSpace {
                PhaseSpace {
                    pos: q.mom / m,
                    mom: -atom_iter.get_trap().gradient(q.pos),
                }
            };
        tq_int = rka((t_cur, t_int_end), atom_iter.q, dt, int_rhs, 1e-6)
            .map_err(ScatterError::NewtonError)?;
        t_cur = tq_int.0.pop().unwrap();
        atom_iter.q = tq_int.1.pop().unwrap();
        trajectory.append(&mut tq_int.0, &mut tq_int.1);
        state_history.push(t_cur, state);
        atom_iter.q.mom
            += photon_int.momentum_kick().unwrap_or(ThreeVector::zero());
whooie's avatar
whooie committed
        if t_cur >= t_final { break; }
    }
    return Ok((
        ScatterHistory { states: state_history, traj: trajectory },
        atom_iter.dump().0,
whooie's avatar
whooie committed
    ));
}

/// Calculate a single state history, disregarding position and momentum, up to
/// `t_final` *or* until the atom goes dark; i.e. the final time in the
/// trajectory is not guaranteed to be equal to `t_final`.
pub fn scatter_sim_static<S, T, R>(atom: &Atom<S, T, R>, t_final: f64)
    -> ScatterResult<(StateHistory<S>, Atom<S, T, R>)>
whooie's avatar
whooie committed
where
    S: State,
    T: Trap,
    R: RadiationPattern,
{
    let mut t_cur: f64 = 0.0;
    let mut state_history = StateHistory::new(vec![t_cur], vec![atom.state])?;
    let mut atom_iter: AtomIterStatic<S, T, R> = atom.to_state_iter_static();
    // simulate until t_final or the atom goes dark
    for (state, photon_int) in atom_iter.by_ref() {
        t_cur += photon_int.time();
        state_history.push(t_cur, state);
        if t_cur >= t_final { break; }
whooie's avatar
whooie committed
    }
    return Ok((state_history, atom_iter.dump()));
whooie's avatar
whooie committed
}