#![allow(unused_imports)]

//! Provides definitions related to internal atomic dynamics and how they couple
//! with incoming and outgoing photons.

use std::{
    collections::HashMap,
    f64::consts::{
        PI,
        TAU,
    },
    hash::Hash,
    fmt::Debug,
};
use num_complex::Complex64 as C64;
use rand::{
    prelude as rnd,
    Rng,
    distributions::Distribution,
};
use statrs::distribution::Exp;
use thiserror::Error;
use wigner_symbols::Wigner3jm;
use crate::{
    newton::{
        ThreeVector,
        PhaseSpace,
    },
    phys::{
        h,
        hbar,
    },
    trap::Trap,
};

#[derive(Error, Debug)]
pub enum AtomError {
    #[error("reached dark state")]
    DarkState,
    #[error("trap missing for state {0}")]
    TrapUndefined(String),
}
pub type AtomResult<T> = Result<T, AtomError>;

/// Population in the excited state of a two-level system after all transient
/// oscillations have been damped away (the steady-state value).
/// ```math
/// \rho_\text{ee}(s, \Delta, \Gamma)
///     = \frac{1}{2} \frac{s}{1 + s + (2 \Delta / \Gamma)^2}
/// ```
/// Detuning and linewidth enter only through their ratio
/// $`\Delta / \Gamma`$, so they merely need to share units (both angular or
/// both non-angular).
pub fn pop_excited(saturation: f64, detuning: f64, linewidth: f64) -> f64 {
    return
        saturation / 2.0
        / (1.0 + saturation + (2.0 * detuning / linewidth).powi(2))
        ;
}

/// Full time-dependent population in the excited state of a two-level system,
/// with $`\rho_{ee}(0) = 0`$. Detuning and linewidth should be given in angular
/// units.
/// ```math
/// \begin{aligned}
///     \rho_{ee}(t; s, \Delta, \Gamma)
///         &= \rho_\infty \left(
///             1 - e^{-\frac{3}{4} \Gamma t} \cos(\omega t)
///         \right)
///     \\
///     \rho_\infty
///         &= \frac{1}{2} \frac{s}{1 + s + (2 \Delta / \Gamma)^2}
///     \\
///     \omega
///         &= \sqrt{\Omega^2 - \left(\frac{\Gamma}{4}\right)^2 + \Delta^2}
///     \\
///     \Omega
///         &= \sqrt{\frac{\Gamma}{2} s}
/// \end{aligned}
/// ```
/// where $`\rho_\infty`$ is the steady-state population from [`pop_excited`].
///
/// If the radicand of $`\omega`$ is negative, $`\omega`$ is NaN and so is the
/// result.
///
/// NOTE(review): $`\Omega = \sqrt{\Gamma s / 2}`$ does not have units of
/// frequency; the conventional relation $`s = 2 \Omega^2 / \Gamma^2`$ would give
/// $`\Omega = \Gamma \sqrt{s / 2}`$. The two agree only when $`\Gamma = 1`$ in
/// the chosen units — confirm which is intended.
pub fn rho_ee(t: f64, saturation: f64, detuning: f64, linewidth: f64) -> f64 {
    let rabi: f64 = (linewidth / 2.0 * saturation).sqrt();
    let w: f64
        = (rabi.powi(2) - (linewidth / 4.0).powi(2) + detuning.powi(2)).sqrt();
    return
        pop_excited(saturation, detuning, linewidth)
        * (1.0 - (-0.75 * linewidth * t).exp() * (w * t).cos())
        ;
}

/// First time derivative of the full time-dependent population in the excited
/// state of a two-level system ([`rho_ee`]). Detuning and linewidth should be
/// given in angular units.
/// ```math
/// \begin{aligned}
///     \dot{\rho}_{ee}(t; s, \Delta, \Gamma)
///         &= \rho_\infty \left(
///             \frac{3}{4} \Gamma e^{-\frac{3}{4} \Gamma t} \cos(\omega t)
///             + \omega e^{-\frac{3}{4} \Gamma t} \sin(\omega t)
///         \right)
///     \\
///     \rho_\infty
///         &= \frac{1}{2} \frac{s}{1 + s + (2 \Delta / \Gamma)^2}
///     \\
///     \omega
///         &= \sqrt{\Omega^2 - \left(\frac{\Gamma}{4}\right)^2 + \Delta^2}
///     \\
///     \Omega
///         &= \sqrt{\frac{\Gamma}{2} s}
/// \end{aligned}
/// ```
/// where $`\rho_\infty`$ is the steady-state population from [`pop_excited`].
/// At $`t = 0`$ this equals $`\frac{3}{4} \Gamma \rho_\infty`$; it is NaN when
/// the radicand of $`\omega`$ is negative.
pub fn drho_ee(t: f64, saturation: f64, detuning: f64, linewidth: f64) -> f64 {
    let rabi: f64 = (linewidth / 2.0 * saturation).sqrt();
    let w: f64
        = (rabi.powi(2) - (linewidth / 4.0).powi(2) + detuning.powi(2)).sqrt();
    return
        pop_excited(saturation, detuning, linewidth)
        * (
            0.75 * linewidth * (-0.75 * linewidth * t).exp() * (w * t).cos()
            + w * (-0.75 * linewidth * t).exp() * (w * t).sin()
        )
        ;
}

/// First maximum $`(t_0, \rho_0)`$ of the full time-dependent population in the
/// excited state of a two-level system ([`rho_ee`]). Detuning and linewidth
/// should be given in angular units.
/// ```math
/// \begin{aligned}
///     t_0
///         &= \frac{2}{\omega} \arctan\left(
///             x + \sqrt{1 + x^2}
///         \right),
///     \qquad x = \frac{4 \omega}{3 \Gamma}
///     \\
///     \rho_0
///         &= \rho_{ee}(t_0; s, \Delta, \Gamma)
///         = \rho_\infty \left(
///             1 - e^{-\frac{3}{4} \Gamma t_0} \cos(\omega t_0)
///         \right)
/// \end{aligned}
/// ```
/// with $`\rho_\infty`$ (the steady state, [`pop_excited`]) and $`\omega`$
/// defined as for [`rho_ee`]. $`t_0`$ is the first $`t > 0`$ at which
/// [`drho_ee`] vanishes, i.e. the first positive solution of
/// $`\tan(\omega t) = -3 \Gamma / (4 \omega)`$, written in half-angle form.
///
/// Both components are NaN when $`\omega`$ is NaN.
pub fn rho_ee_max(saturation: f64, detuning: f64, linewidth: f64) -> (f64, f64)
{
    let rabi: f64 = (linewidth / 2.0 * saturation).sqrt();
    let w: f64
        = (rabi.powi(2) - (linewidth / 4.0).powi(2) + detuning.powi(2)).sqrt();
    let t0: f64
        = 2.0 / w
        * (
            4.0 * w / 3.0 / linewidth
            + (1.0 + (4.0 * w / 3.0 / linewidth).powi(2)).sqrt()
        ).atan();
    let rho0: f64
        = pop_excited(saturation, detuning, linewidth)
        * (1.0 - (-0.75 * linewidth * t0).exp() * (w * t0).cos());
    return (t0, rho0);
}

/// Inverts the [$`\rho_{ee}`$][rho_ee] function for inverse transform
/// sampling: finds $`t \in [0, t_0]`$ with $`\rho_{ee}(t) = r`$ for a given
/// probability value $`r`$, where $`(t_0, \rho_0)`$ is the first maximum from
/// [`rho_ee_max`]. Detuning and linewidth should be given in angular units.
///
/// The root is found by Newton-Raphson safeguarded by bisection. A bracket
/// $`[lo, hi]`$ with $`\rho_{ee}(lo) \le r \le \rho_{ee}(hi)`$ is maintained, and
/// any Newton step that is non-finite (e.g. from the vanishing derivative at
/// $`t_0`$) or would leave the bracket is replaced by a bisection step.
/// Convergence is declared once a step is at most $`10^{-6} t_0`$, so the
/// relative precision does not depend on the time units in use.
///
/// Returns `f64::INFINITY` if $`r > \rho_0`$.
///
/// # Panics
/// Panics if $`r \not\in [0, 1]`$, or if the method fails to converge within
/// 1000 iterations (e.g. when $`t_0`$ is NaN).
pub fn rho_ee_inv(saturation: f64, detuning: f64, linewidth: f64, r: f64)
    -> f64
{
    if !(0.0..=1.0).contains(&r) {
        panic!("rho_ee_inv: encountered invalid probability value");
    }
    let (t0, rho0): (f64, f64) = rho_ee_max(saturation, detuning, linewidth);
    if r > rho0 {
        return f64::INFINITY;
    }
    // relative tolerance: the same precision whatever the time units are
    let tol: f64 = 1e-6 * t0;
    // invariant: rho_ee(lo) <= r <= rho_ee(hi); it holds initially because
    // rho_ee(0) = 0 <= r and rho_ee(t0) = rho0 >= r
    let mut lo: f64 = 0.0;
    let mut hi: f64 = t0;
    let mut t: f64 = t0 / 2.0;
    for _ in 0..1000 {
        let f: f64 = rho_ee(t, saturation, detuning, linewidth) - r;
        if f == 0.0 {
            return t;
        }
        if f > 0.0 { hi = t; } else { lo = t; }
        let newton: f64
            = t - f / drho_ee(t, saturation, detuning, linewidth);
        // fall back to bisection whenever the Newton step is unusable
        let t_next: f64
            = if newton.is_finite() && lo < newton && newton < hi {
                newton
            } else {
                0.5 * (lo + hi)
            };
        let dt: f64 = t_next - t;
        t = t_next;
        if dt.abs() <= tol {
            return t;
        }
    }
    panic!("rho_ee_inv: failed to converge");
}

/// Mean excitation time, treating [$`\rho_{ee}`$][rho_ee] rescaled by its
/// first maximum $`\rho_0`$ as a cumulative distribution function on
/// $`[0, t_0]`$, with $`(t_0, \rho_0)`$ from [`rho_ee_max`]. Detuning and
/// linewidth should be given in angular units.
///
/// Evaluates the closed form
/// ```math
/// \langle t \rangle = \frac{
///     12 \Gamma K_1 - K_2 t_0 \cos(\omega t_0) + 16 \omega \sin(\omega t_0)
/// }{K_1 K_2},
/// \qquad K_1 = e^{\frac{3}{4} \Gamma t_0} - \cos(\omega t_0),
/// \qquad K_2 = 16 \omega^2 + 9 \Gamma^2
/// ```
/// with $`\omega`$ defined as for [`rho_ee`].
pub fn rho_ee_mean_time(saturation: f64, detuning: f64, linewidth: f64) -> f64 {
    let rabi: f64 = (linewidth / 2.0 * saturation).sqrt();
    let omega: f64
        = (rabi.powi(2) - (linewidth / 4.0).powi(2) + detuning.powi(2)).sqrt();
    let t0: f64 = rho_ee_max(saturation, detuning, linewidth).0;
    let phase: f64 = omega * t0;
    let k1: f64 = (0.75 * linewidth * t0).exp() - phase.cos();
    let k2: f64 = 16.0 * omega.powi(2) + 9.0 * linewidth.powi(2);
    let numerator: f64
        = 12.0 * linewidth * k1
        - k2 * t0 * phase.cos()
        + 16.0 * omega * phase.sin();
    numerator / (k1 * k2)
}

/// Absolute value of the Clebsch-Gordan coefficient coupling two spin states
/// $`|F^0, m_F^0\rangle`$ and $`|F^1, m_F^1\rangle`$ with a (spin-1) photon.
///
/// Each argument is a `(total, projection)` pair. The value comes from the
/// corresponding Wigner 3-$`j`$ symbol,
/// ```math
/// \left|
///     \sqrt{2 F^0 + 1}
///     \begin{pmatrix}
///         F^1   & 1             & F^0
///         \\
///         m_F^1 & m_F^0 - m_F^1 & -m_F^0
///     \end{pmatrix}
/// \right|
/// ```
/// Quantum numbers are passed to `Wigner3jm` doubled and cast with `as i32`,
/// so inputs that are not (half-)integers are truncated toward zero.
pub fn cg(spin0: (f64, f64), spin1: (f64, f64)) -> f64
{
    // Wigner3jm stores every quantum number as twice its value so that
    // half-integer spins become integers.
    let doubled = |x: f64| -> i32 { (2.0 * x) as i32 };
    let (f0, m0) = spin0;
    let (f1, m1) = spin1;
    let symbol = Wigner3jm {
        tj1: doubled(f1),
        tm1: doubled(m1),
        tj2: 2_i32,
        tm2: doubled(m0 - m1),
        tj3: doubled(f0),
        tm3: doubled(-m0),
    };
    let value: f64 = f64::from(symbol.value());
    (2.0 * f0 + 1.0).sqrt() * value.abs()
}

/// Angular distribution of radiated photons over the spherical angles
/// $`\theta, \phi`$, where $`\phi`$ is the azimuthal angle.
pub trait RadiationPattern: Copy + Clone {
    /// Draw one pair of angles $`(\theta, \phi)`$ from this pattern using
    /// `rng`; $`\phi`$ is the azimuthal angle.
    fn sample_angles_rng<R>(&self, rng: &mut R) -> (f64, f64)
    where R: Rng + ?Sized;

    /// Same as `sample_angles_rng`, drawing from the thread-local RNG.
    fn sample_angles(&self) -> (f64, f64) {
        self.sample_angles_rng(&mut rnd::thread_rng())
    }

    /// Draw the photon's momentum vector for angular wavenumber `k`. The
    /// default gives it magnitude `hbar * k` (i.e. assumes SI units) along a
    /// direction from `sample_angles_rng`; override for other unit systems.
    fn sample_momentum_rng<R>(&self, k: f64, rng: &mut R) -> ThreeVector
    where R: Rng + ?Sized
    {
        let (theta, phi) = self.sample_angles_rng(rng);
        let magnitude: f64 = hbar * k;
        ThreeVector::from_angles(magnitude, theta, phi)
    }

    /// Same as `sample_momentum_rng`, drawing from the thread-local RNG; uses
    /// the same units.
    fn sample_momentum(&self, k: f64) -> ThreeVector {
        self.sample_momentum_rng(k, &mut rnd::thread_rng())
    }

    /// Recoil applied to the atom for a photon of angular wavenumber `k`: the
    /// negation of `sample_momentum_rng`, in the same units.
    fn sample_momentum_kick_rng<R>(&self, k: f64, rng: &mut R) -> ThreeVector
    where R: Rng + ?Sized
    {
        let photon: ThreeVector = self.sample_momentum_rng(k, rng);
        -photon
    }

    /// Recoil applied to the atom for a photon of angular wavenumber `k`: the
    /// negation of `sample_momentum`, in the same units.
    fn sample_momentum_kick(&self, k: f64) -> ThreeVector {
        let photon: ThreeVector = self.sample_momentum(k);
        -photon
    }
}

/// Isotropic radiation pattern: emission directions are uniformly distributed
/// over the sphere.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct RadUniform { }

impl RadUniform {
    /// Construct the (stateless) uniform pattern.
    pub fn new() -> Self {
        Self { }
    }
}

impl RadiationPattern for RadUniform {
    fn sample_angles_rng<R>(&self, rng: &mut R) -> (f64, f64)
    where R: Rng + ?Sized
    {
        // cos(theta) is drawn uniformly, which is what makes the sampled
        // directions uniform over the sphere
        let u: f64 = rng.gen::<f64>();
        let theta: f64 = (1.0 - 2.0 * u).acos();
        let phi: f64 = rng.gen::<f64>() * TAU;
        (theta, phi)
    }
}

/// Electric-dipole radiation pattern with the quantization axis fixed along
/// $`z`$.
///
/// The pattern mixes a $`\pi`$ (linearly polarized) component, distributed as
/// $`\sin^2 \theta`$, with a $`\sigma`$ (circularly polarized) component,
/// distributed as $`1 + \cos^2 \theta`$, in a fixed relative proportion.
#[derive(Copy, Clone, Debug)]
pub struct RadDipole {
    pi: f64,
    sigma: f64,
}

impl RadDipole {
    /// Create a pattern from the relative weights of the $`\pi`$-polarized
    /// (oscillating dipole) and $`\sigma`$-polarized (rotating dipole) modes,
    /// with the oscillation axis along $`z`$. Both weights are divided by their
    /// sum, so only their ratio matters.
    pub fn new(pi: f64, sigma: f64) -> Self {
        let total: f64 = pi + sigma;
        Self { pi: pi / total, sigma: sigma / total }
    }

    /// Weights $`(1/3, 2/3)`$, for which the combined distribution is
    /// isotropic.
    pub fn uniform() -> Self {
        Self::new(1.0_f64 / 3.0, 2.0_f64 / 3.0)
    }

    /// Probability density per unit solid angle at polar angle `theta`.
    fn pdf_theta(&self, theta: f64) -> f64 {
        let sin_sq: f64 = theta.sin().powi(2);
        let cos_sq: f64 = theta.cos().powi(2);
        self.pi * 3.0 / 8.0 / PI * sin_sq
            + self.sigma * 3.0 / 16.0 / PI * (1.0 + cos_sq)
    }

    /// Probability of emission at a polar angle in $`[0, \theta]`$.
    fn cdf_theta(&self, theta: f64) -> f64 {
        let c: f64 = theta.cos();
        0.5
            - (2.0 * self.pi + self.sigma) * 3.0 / 8.0 * c
            + (2.0 * self.pi - self.sigma) / 8.0 * c.powi(3)
    }

    /// Invert `cdf_theta` by Newton-Raphson, starting from $`\theta = \pi/2`$
    /// and clamping iterates to $`[10^{-6}, (1 - 10^{-6})\pi]`$. The iterate is
    /// returned before clamping, so it may lie marginally outside that range.
    ///
    /// # Panics
    /// Panics if the iteration has not converged after 1000 steps. `r` itself
    /// is not validated; values outside $`[0, 1]`$ have no solution and are
    /// expected to end in this panic.
    fn cdf_inv_theta(&self, r: f64) -> f64 {
        let lower: f64 = 1e-6;
        let upper: f64 = (1.0 - 1e-6) * PI;
        let mut theta: f64 = PI / 2.0;
        for _ in 0..1000 {
            // d(cdf)/d(theta) = pdf * 2 pi sin(theta)
            let slope: f64 = self.pdf_theta(theta) * TAU * theta.sin();
            let step: f64 = (self.cdf_theta(theta) - r) / slope;
            theta -= step;
            if step.abs() < 1e-6 {
                return theta;
            }
            theta = theta.max(lower).min(upper);
        }
        panic!("RadDipole::cdf_inv_theta: failed to converge");
    }
}

impl RadiationPattern for RadDipole {
    fn sample_angles_rng<R>(&self, rng: &mut R) -> (f64, f64)
    where R: Rng + ?Sized
    {
        let theta: f64 = self.cdf_inv_theta(rng.gen::<f64>());
        let phi: f64 = TAU * rng.gen::<f64>();
        (theta, phi)
    }
}

/// Laser beam parameterization. The beam axis passes through the origin along
/// the direction of `momentum`.
#[derive(Copy, Clone, Debug)]
pub struct Laser {
    /// Peak (on-axis) saturation parameter.
    pub saturation: f64,
    /// 1/e^2 radius; distance
    pub radius: f64,
    /// Relative to the free-space limit; time^-1 (non-angular)
    pub detuning: f64,
    /// Photon momentum, which also fixes the beam direction; units of
    /// mass.distance.time^-1
    pub momentum: ThreeVector,
}

impl Laser {
    /// Perpendicular distance between position `r` and the beam axis (the
    /// line through the origin along `self.momentum`).
    ///
    /// Evaluated as $`|\vec{p} \times (\vec{p} \times \vec{r})| / |\vec{p}|^2`$.
    /// The result is NaN if `momentum` is the zero vector.
    pub fn perp_dist(&self, r: ThreeVector) -> f64 {
        let p: ThreeVector = self.momentum;
        let cx: f64
            = - (p.1.powi(2) + p.2.powi(2)) * r.0
            + p.0 * p.1 * r.1
            + p.0 * p.2 * r.2;
        let cy: f64
            = p.0 * p.1 * r.0
            - (p.0.powi(2) + p.2.powi(2)) * r.1
            + p.1 * p.2 * r.2;
        let cz: f64
            = p.0 * p.2 * r.0
            + p.1 * p.2 * r.1
            - (p.0.powi(2) + p.1.powi(2)) * r.2;
        (cx.powi(2) + cy.powi(2) + cz.powi(2)).sqrt() / p.norm().powi(2)
    }
}

/// An atomic state labelled by total and projected spin quantum numbers.
pub trait State: Copy + Clone + Debug + PartialEq + Eq + Hash {
    /// Total spin quantum number.
    fn spin_total(&self) -> f64;

    /// Projected spin quantum number.
    fn spin_proj(&self) -> f64;

    /// `(total, projection)` pair of spin quantum numbers.
    fn spin(&self) -> (f64, f64) {
        let total: f64 = self.spin_total();
        let proj: f64 = self.spin_proj();
        (total, proj)
    }

    /// Absolute value of the photon-coupling Clebsch-Gordan coefficient
    /// between `self` and `other` (see the free function `cg`).
    fn cg<S>(&self, other: &S) -> f64
    where S: State
    {
        let (here, there) = (self.spin(), other.spin());
        cg(here, there)
    }

    /// Square of the photon-coupling Clebsch-Gordan coefficient between `self`
    /// and `other`.
    fn cg_sq<S>(&self, other: &S) -> f64
    where S: State
    {
        let amplitude: f64 = cg(self.spin(), other.spin());
        amplitude.powi(2)
    }
}

/// Lightweight tag naming which of the two `Transition` variants a transition
/// is.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum TransitionKind {
    /// Ground → excited (laser-driven).
    Exciting = 0,
    /// Excited → ground.
    Decaying = 1,
}

/// Data for a sampled photon absorption.
#[derive(Copy, Clone, Debug)]
pub struct Absorption {
    /// Sampled time until excitation. May be infinite, indicating that the
    /// photon missed the atom.
    pub excite_time: f64,
    /// Mean excitation time under the same conditions.
    pub excite_time_mean: f64,
    /// Momentum kick applied to the atom.
    pub momentum_kick: ThreeVector,
}

/// Data for a sampled photon emission.
#[derive(Copy, Clone, Debug)]
pub struct Radiation {
    /// Sampled time until decay.
    pub decay_time: f64,
    /// Mean decay time.
    pub decay_time_mean: f64,
    /// Momentum kick (recoil) applied to the atom.
    pub momentum_kick: ThreeVector,
}

/// One sampled photon interaction with an atom: either an absorption or an
/// emission.
#[derive(Copy, Clone, Debug)]
pub enum PhotonInteraction {
    /// A photon was absorbed.
    Absorption(Absorption),
    /// A photon was radiated.
    Radiation(Radiation),
}

impl PhotonInteraction {
    /// Returns the excitation time if `self` is an absorption.
    pub fn excite_time(&self) -> Option<f64> {
        return match self {
            Self::Absorption(a) => Some(a.excite_time),
            Self::Radiation(_) => None,
        };
    }

    /// Returns the mean excitation time if `self` is an absorption.
    pub fn excite_time_mean(&self) -> Option<f64> {
        return match self {
            Self::Absorption(a) => Some(a.excite_time_mean),
            Self::Radiation(_) => None,
        };
    }

    /// Returns the decay time if `self` is a radiation.
    pub fn decay_time(&self) -> Option<f64> {
        return match self {
            Self::Absorption(_) => None,
            Self::Radiation(r) => Some(r.decay_time),
        };
    }

    /// Returns the mean decay time if `self` is a radiation.
    pub fn decay_time_mean(&self) -> Option<f64> {
        return match self {
            Self::Absorption(_) => None,
            Self::Radiation(r) => Some(r.decay_time_mean),
        };
    }

    /// Returns the transition time, regardless of interaction type.
    pub fn transition_time(&self) -> f64 {
        return match self {
            Self::Absorption(a) => a.excite_time,
            Self::Radiation(r) => r.decay_time,
        };
    }

    /// Returns the mean transition time, regardless of interaction type.
    pub fn transition_time_mean(&self) -> f64 {
        return match self {
            Self::Absorption(a) => a.excite_time_mean,
            Self::Radiation(r) => r.decay_time_mean,
        };
    }

    /// Get the momentum kick to apply to the atom.
    pub fn momentum_kick(&self) -> ThreeVector {
        return match self {
            Self::Absorption(a) => a.momentum_kick,
            Self::Radiation(r) => r.momentum_kick,
        };
    }
}

/// Parameters of one specific transition between two atomic states.
#[derive(Copy, Clone, Debug)]
pub enum Transition<S, R>
where
    S: State,
    R: RadiationPattern,
{
    /// Ground → excited, driven by `laser`.
    Exciting {
        /// Lower state.
        ground: S,
        /// Upper state.
        excited: S,
        /// Units of distance
        wavelength: f64,
        /// Units of frequency (non-angular)
        linewidth: f64,
        /// Beam driving the transition.
        laser: Laser,
    },
    /// Excited → ground, radiating according to `radiation`.
    Decaying {
        /// Lower state.
        ground: S,
        /// Upper state.
        excited: S,
        /// Units of distance
        wavelength: f64,
        /// Units of frequency (non-angular)
        linewidth: f64,
        /// Angular distribution of the emitted photon.
        radiation: R,
    },
}

impl<S, R> Transition<S, R>
where
    S: State,
    R: RadiationPattern,
{
    /// Build an `Exciting` transition driven by `laser`.
    pub fn new_exciting(
        ground: S,
        excited: S,
        wavelength: f64,
        linewidth: f64,
        laser: Laser,
    ) -> Self
    {
        Self::Exciting { ground, excited, wavelength, linewidth, laser }
    }

    /// Build a `Decaying` transition radiating according to `radiation`.
    pub fn new_decaying(
        ground: S,
        excited: S,
        wavelength: f64,
        linewidth: f64,
        radiation: R,
    ) -> Self
    {
        Self::Decaying { ground, excited, wavelength, linewidth, radiation }
    }

    /// The [`TransitionKind`] tag matching this variant.
    pub fn kind(&self) -> TransitionKind {
        match self {
            Self::Exciting { .. } => TransitionKind::Exciting,
            Self::Decaying { .. } => TransitionKind::Decaying,
        }
    }

    /// `true` for the `Exciting` variant.
    pub fn is_exciting(&self) -> bool {
        self.kind() == TransitionKind::Exciting
    }

    /// `true` for the `Decaying` variant.
    pub fn is_decaying(&self) -> bool {
        self.kind() == TransitionKind::Decaying
    }

    /// Copy of the lower state.
    pub fn ground_state(&self) -> S { *self.get_ground_state() }

    /// Reference to the lower state.
    pub fn get_ground_state(&self) -> &S {
        match self {
            Self::Exciting { ground, .. } | Self::Decaying { ground, .. }
                => ground,
        }
    }

    /// Copy of the upper state.
    pub fn excited_state(&self) -> S { *self.get_excited_state() }

    /// Reference to the upper state.
    pub fn get_excited_state(&self) -> &S {
        match self {
            Self::Exciting { excited, .. } | Self::Decaying { excited, .. }
                => excited,
        }
    }

    /// Copy of the state the transition leaves (ground if exciting, excited if
    /// decaying).
    pub fn start_state(&self) -> S { *self.get_start_state() }

    /// Reference to the state the transition leaves.
    pub fn get_start_state(&self) -> &S {
        match self {
            Self::Exciting { ground, .. } => ground,
            Self::Decaying { excited, .. } => excited,
        }
    }

    /// Copy of the state the transition reaches (excited if exciting, ground
    /// if decaying).
    pub fn end_state(&self) -> S { *self.get_end_state() }

    /// Reference to the state the transition reaches.
    pub fn get_end_state(&self) -> &S {
        match self {
            Self::Exciting { excited, .. } => excited,
            Self::Decaying { ground, .. } => ground,
        }
    }

    /// Whether the transition leaves `state`.
    pub fn starts_with(&self, state: &S) -> bool {
        self.get_start_state() == state
    }

    /// `Some(starts_with(state))` for exciting transitions, `None` otherwise.
    pub fn exciting_starts_with(&self, state: &S) -> Option<bool> {
        self.is_exciting().then(|| self.starts_with(state))
    }

    /// `Some(starts_with(state))` for decaying transitions, `None` otherwise.
    pub fn decaying_starts_with(&self, state: &S) -> Option<bool> {
        self.is_decaying().then(|| self.starts_with(state))
    }

    /// Whether `self` and `other` are of the same kind and leave the same
    /// state.
    pub fn same_starts_with(&self, other: &Self) -> bool {
        self.kind() == other.kind()
            && self.get_start_state() == other.get_start_state()
    }

    /// Whether the transition reaches `state`.
    pub fn ends_with(&self, state: &S) -> bool {
        self.get_end_state() == state
    }

    /// `Some(ends_with(state))` for exciting transitions, `None` otherwise.
    pub fn exciting_ends_with(&self, state: &S) -> Option<bool> {
        self.is_exciting().then(|| self.ends_with(state))
    }

    /// `Some(ends_with(state))` for decaying transitions, `None` otherwise.
    pub fn decaying_ends_with(&self, state: &S) -> Option<bool> {
        self.is_decaying().then(|| self.ends_with(state))
    }

    /// Whether `self` and `other` are of the same kind and reach the same
    /// state.
    pub fn same_ends_with(&self, other: &Self) -> bool {
        self.kind() == other.kind()
            && self.get_end_state() == other.get_end_state()
    }

    /// Transition wavelength.
    pub fn wavelength(&self) -> f64 {
        match *self {
            Self::Exciting { wavelength, .. }
            | Self::Decaying { wavelength, .. } => wavelength,
        }
    }

    /// Transition linewidth.
    pub fn linewidth(&self) -> f64 {
        match *self {
            Self::Exciting { linewidth, .. }
            | Self::Decaying { linewidth, .. } => linewidth,
        }
    }

    /// Driving laser, for exciting transitions only.
    pub fn laser(&self) -> Option<Laser> {
        match self {
            Self::Exciting { laser, .. } => Some(*laser),
            Self::Decaying { .. } => None,
        }
    }

    /// Radiation pattern, for decaying transitions only.
    pub fn radiation(&self) -> Option<R> {
        match self {
            Self::Exciting { .. } => None,
            Self::Decaying { radiation, .. } => Some(*radiation),
        }
    }
}

/// Holds all information from which transition probabilities can be
/// calculated, and drives the internal atomic state dynamics.
#[derive(Clone, Debug)]
pub struct StateGraph<S, T, R>
where
    S: State,
    T: Trap,
    R: RadiationPattern,
{
    /// Every known transition; duplicates are kept and sampled separately.
    transitions: Vec<Transition<S, R>>,
    /// Trapping potential for each state.
    traps: HashMap<S, T>,
}

impl<S, T, R> StateGraph<S, T, R>
where
    S: State,
    T: Trap,
    R: RadiationPattern,
{
    /// Create a new `StateGraph`. All duplicate transitions are counted
    /// separately. All states involed in transitions passed here must have an
    /// associated trap.
    pub fn new<I, J>(transitions: I, traps: J) -> AtomResult<Self>
    where
        I: IntoIterator<Item = Transition<S, R>>,
        J: IntoIterator<Item = (S, T)>,
    {
        let state_traps: HashMap<S, T> = traps.into_iter().collect();
        let transition_list: Vec<Transition<S, R>>
            = transitions.into_iter()
            .map(|t| {
                if !state_traps.contains_key(t.get_ground_state()) {
                    Err(AtomError::TrapUndefined(
                        format!("{:?}", t.get_ground_state())
                    ))
                } else if !state_traps.contains_key(t.get_excited_state()) {
                    Err(AtomError::TrapUndefined(
                        format!("{:?}", t.get_excited_state())
                    ))
                } else {
                    Ok(t)
                }
            })
            .collect::<AtomResult<Vec<Transition<S, R>>>>()?;
        return Ok(Self { transitions: transition_list, traps: state_traps });
    }

    /// Get the trapping potential for a state.
    pub fn get_trap(&self, state: &S) -> Option<&T> { self.traps.get(state) }

    /// Sample a photon interaction for a given transition.
    fn sample_photon_int<G>(
        &self,
        transition: &Transition<S, R>,
        q: &PhaseSpace,
        rng: &mut G,
    ) -> PhotonInteraction
    where G: Rng + ?Sized
    {
        return match transition {
            Transition::Exciting {
                ground,
                excited,
                wavelength: _,
                linewidth,
                laser,
            } => {
                let det: f64
                    = laser.detuning
                    + self.get_trap(excited).unwrap().light_shift(q.pos)
                    - self.get_trap(ground).unwrap().light_shift(q.pos);
                let s_eff: f64
                    = laser.saturation
                    * (
                        -2.0 * (laser.perp_dist(q.pos) / laser.radius).powi(2)
                    ).exp()
                    * ground.cg_sq(excited);
                let excite_time: f64
                    = rho_ee_inv(s_eff, det, *linewidth, rng.gen::<f64>());
                let excite_time_mean: f64
                    = rho_ee_mean_time(s_eff, det, *linewidth);
                let momentum_kick: ThreeVector = laser.momentum;
                PhotonInteraction::Absorption(
                    Absorption { excite_time, excite_time_mean, momentum_kick }
                )
            },
            Transition::Decaying {
                ground: _,
                excited: _,
                wavelength,
                linewidth,
                radiation,
            } => {
                let decay_time: f64
                    = -2.0 / linewidth * (1.0 - rng.gen::<f64>()).ln();
                let decay_time_mean: f64
                    = 2.0 / linewidth;
                let momentum_kick: ThreeVector
                    = radiation.sample_momentum_kick_rng(TAU / wavelength, rng);
                PhotonInteraction::Radiation(
                    Radiation { decay_time, decay_time_mean, momentum_kick }
                )
            },
        };
    }

    /// Sample a transition and corresponding photon interaction. Fails if the
    /// current state is completely dark and can't decay.
    pub fn next_state_checked_rng<G>(
        &self,
        current_state: &S,
        q: PhaseSpace,
        rng: &mut G,
    ) -> AtomResult<(S, PhotonInteraction)>
    where G: Rng + ?Sized
    {
        return self.transitions.iter()
            .filter_map(|t| {
                t.starts_with(current_state)
                    .then_some(
                        (t, self.sample_photon_int(t, &q, rng))
                    )
            })
            .min_by(|(_tl, pl), (_tr, pr)| {
                pl.transition_time().partial_cmp(&pr.transition_time())
                    .unwrap_or(std::cmp::Ordering::Less)
            })
            .ok_or(AtomError::DarkState)
            .map(|(t, p)| (t.end_state(), p));
    }

    /// Sample a transition and corresponding photon interaction. Fails if the
    /// current state is completely dark and can't decay.
    pub fn next_state_checked(&self, current_state: &S, q: PhaseSpace)
        -> AtomResult<(S, PhotonInteraction)>
    {
        let mut rng = rnd::thread_rng();
        return self.next_state_checked_rng(current_state, q, &mut rng);
    }

    /// Sample a photon interaction for a given transition, disregarding
    /// position and momentum.
    fn sample_photon_int_static<G>(
        &self,
        transition: &Transition<S, R>,
        rng: &mut G,
    ) -> PhotonInteraction
    where G: Rng + ?Sized
    {
        return match transition {
            Transition::Exciting {
                ground,
                excited,
                wavelength: _,
                linewidth,
                laser,
            } => {
                let det: f64
                    = laser.detuning
                    - self.get_trap(excited).unwrap().depth()
                    + self.get_trap(ground).unwrap().depth();
                let s_eff: f64
                    = laser.saturation
                    * ground.cg_sq(excited);
                let excite_time: f64
                    = rho_ee_inv(s_eff, det, *linewidth, rng.gen::<f64>());
                let excite_time_mean: f64
                    = rho_ee_mean_time(s_eff, det, *linewidth);
                let momentum_kick: ThreeVector = laser.momentum;
                PhotonInteraction::Absorption(
                    Absorption { excite_time, excite_time_mean, momentum_kick }
                )
            },
            Transition::Decaying {
                ground: _,
                excited: _,
                wavelength,
                linewidth,
                radiation,
            } => {
                let decay_time: f64
                    = -2.0 / linewidth * (1.0 - rng.gen::<f64>()).ln();
                let decay_time_mean: f64
                    = 2.0 / linewidth;
                let momentum_kick: ThreeVector
                    = radiation.sample_momentum_kick_rng(TAU / wavelength, rng);
                PhotonInteraction::Radiation(
                    Radiation { decay_time, decay_time_mean, momentum_kick }
                )
            },
        };
    }

    /// Sample a transition and corresponding photon interaction, disregarding
    /// position and momentum. Fails if the current state is completely dark and
    /// can't decay.
    pub fn next_state_static_checked_rng<G>(
        &self,
        current_state: &S,
        rng: &mut G,
    ) -> AtomResult<(S, PhotonInteraction)>
    where G: Rng + ?Sized
    {
        return self.transitions.iter()
            .filter_map(|t| {
                t.starts_with(current_state)
                    .then_some(
                        (t, self.sample_photon_int_static(t, rng))
                    )
            })
            .min_by(|(_tl, pl), (_tr, pr)| {
                pl.transition_time().partial_cmp(&pr.transition_time())
                    .unwrap_or(std::cmp::Ordering::Less)
            })
            .ok_or(AtomError::DarkState)
            .map(|(t, p)| (t.end_state(), p));
    }

    /// Sample a transition and corresponding photon interaction, disregarding
    /// position and momentum dependences. Fails if the current state is
    /// completely dark and can't decay.
    pub fn next_state_static_checked(&self, current_state: &S)
        -> AtomResult<(S, PhotonInteraction)>
    {
        let mut rng = rnd::thread_rng();
        return self.next_state_static_checked_rng(current_state, &mut rng);
    }

    /// Sample a transition and corresponding photon interaction, if possible.
    pub fn next_state_rng<G>(
        &self,
        current_state: &S,
        q: PhaseSpace,
        rng: &mut G
    ) -> (S, Option<PhotonInteraction>)
    where G: Rng + ?Sized
    {
        return self.next_state_checked_rng(current_state, q, rng)
            .map(|(state, photon)| (state, Some(photon)))
            .unwrap_or((*current_state, None));
    }

    /// Sample a transition and corresponding photon interaction, if possible.
    pub fn next_state(
        &self,
        current_state: &S,
        q: PhaseSpace,
    ) -> (S, Option<PhotonInteraction>)
    {
        return self.next_state_checked(current_state, q)
            .map(|(state, photon)| (state, Some(photon)))
            .unwrap_or((*current_state, None));
    }

    /// Sample a transition and corresponding photon interaction, if possible,
    /// disregarding position and momentum dependencies.
    pub fn next_state_static_rng<G>(&self, current_state: &S, rng: &mut G)
        -> (S, Option<PhotonInteraction>)
    where G: Rng + ?Sized
    {
        return self.next_state_static_checked_rng(current_state, rng)
            .map(|(state, photon)| (state, Some(photon)))
            .unwrap_or((*current_state, None));
    }

    /// Sample a transition and corresponding photon interaction, if possible,
    /// disregarding position and momentum dependencies.
    pub fn next_state_static(&self, current_state: &S)
        -> (S, Option<PhotonInteraction>)
    {
        return self.next_state_static_checked(current_state)
            .map(|(state, photon)| (state, Some(photon)))
            .unwrap_or((*current_state, None));
    }
}

/// Top-level handle on an atom moving through the internal-state Markov chain
/// described by a [`StateGraph`].
///
/// [`AtomIter`] and [`AtomIterStatic`] (built with e.g.
/// [`Atom::into_state_iter`]) drive the chain continuously through the
/// `Iterator` interface.
#[derive(Clone, Debug)]
pub struct Atom<S: State, T: Trap, R: RadiationPattern> {
    /// Current internal state.
    pub state: S,
    // Transition structure and per-state trapping potentials.
    state_graph: StateGraph<S, T, R>,
    /// Atomic mass, passed to the trap when sampling phase-space vectors
    /// (units are not fixed here; presumably SI, confirm against `Trap`).
    pub mass: f64,
    /// Temperature passed to the trap when sampling phase-space vectors.
    pub temperature: f64,
}

impl<S, T, R> Atom<S, T, R>
where
    S: State,
    T: Trap,
    R: RadiationPattern,
{
    /// Create a new atom in `state`, governed by `state_graph`.
    ///
    /// # Errors
    /// Returns [`AtomError::TrapUndefined`] if `state_graph` has no trap
    /// defined for `state`.
    pub fn new(
        state: S,
        state_graph: StateGraph<S, T, R>,
        mass: f64,
        temperature: f64,
    ) -> AtomResult<Self>
    {
        state_graph.get_trap(&state)
            .ok_or_else(|| AtomError::TrapUndefined(format!("{:?}", state)))?;
        return Ok(Self {
            state,
            state_graph,
            mass,
            temperature,
        });
    }

    /// Get the current state.
    pub fn state(&self) -> S { self.state }

    /// Get a reference to the current state.
    pub fn get_state(&self) -> &S { &self.state }

    /// Get the trapping potential for the current state.
    ///
    /// # Panics
    /// Panics if the state graph has no trap for the current state.
    /// [`Atom::new`] checks this for the initial state, but `state` is a
    /// public field and may be reassigned afterwards. Use
    /// [`Atom::get_trap_for`] for a non-panicking lookup.
    pub fn get_trap(&self) -> &T {
        return self.state_graph.get_trap(&self.state)
            .unwrap_or_else(|| panic!("trap missing for state {:?}", self.state));
    }

    /// Get the trapping potential for a state, or `None` if none is defined.
    pub fn get_trap_for(&self, state: &S) -> Option<&T> {
        return self.state_graph.get_trap(state);
    }

    /// Creates an `Iterator` interface for the Markov chain.
    ///
    /// The initial phase-space vector is sampled from `init`'s trap using
    /// `mass` and `temperature`.
    ///
    /// # Errors
    /// Returns [`AtomError::TrapUndefined`] if `state_graph` has no trap
    /// defined for `init`.
    pub fn state_iter(
        init: S,
        state_graph: StateGraph<S, T, R>,
        mass: f64,
        temperature: f64,
    ) -> AtomResult<AtomIter<S, T, R>>
    {
        // `new` validates that `init` has a trap, so `get_trap` cannot panic
        // here; this avoids repeating the lookup and error construction.
        let atom = Self::new(init, state_graph, mass, temperature)?;
        let mut rng = rnd::thread_rng();
        let q: PhaseSpace = atom.get_trap()
            .sample_phasespace_rng(mass, temperature, &mut rng);
        return Ok(AtomIter { atom, q, rng });
    }

    /// Creates an `Iterator` interface for the position-independent Markov
    /// chain.
    ///
    /// # Errors
    /// Returns [`AtomError::TrapUndefined`] if `state_graph` has no trap
    /// defined for `init`.
    pub fn state_iter_static(
        init: S,
        state_graph: StateGraph<S, T, R>,
        mass: f64,
        temperature: f64,
    ) -> AtomResult<AtomIterStatic<S, T, R>>
    {
        let rng = rnd::thread_rng();
        let atom = Self::new(init, state_graph, mass, temperature)?;
        return Ok(AtomIterStatic { atom, rng });
    }

    /// Clones `self` into an `Iterator` interface for the Markov chain. A fresh
    /// phase-space vector is sampled from the current state's trap (see
    /// [`Atom::state_iter`]).
    ///
    /// # Panics
    /// Panics if no trap is defined for the current state.
    pub fn to_state_iter(&self) -> AtomIter<S, T, R> {
        return Self::state_iter(
            self.state,
            self.state_graph.clone(),
            self.mass,
            self.temperature,
        ).unwrap_or_else(|err| panic!("{}", err));
    }

    /// Clones `self` into an `Iterator` interface for the position-independent
    /// Markov chain.
    ///
    /// # Panics
    /// Panics if no trap is defined for the current state.
    pub fn to_state_iter_static(&self) -> AtomIterStatic<S, T, R> {
        return Self::state_iter_static(
            self.state,
            self.state_graph.clone(),
            self.mass,
            self.temperature,
        ).unwrap_or_else(|err| panic!("{}", err));
    }

    /// Converts `self` to an `Iterator` interface for the Markov chain. A fresh
    /// phase-space vector is sampled from the current state's trap (see
    /// [`Atom::state_iter`]).
    ///
    /// # Panics
    /// Panics if no trap is defined for the current state.
    pub fn into_state_iter(self) -> AtomIter<S, T, R> {
        return Self::state_iter(
            self.state,
            self.state_graph,
            self.mass,
            self.temperature,
        ).unwrap_or_else(|err| panic!("{}", err));
    }

    /// Converts `self` to an `Iterator` interface for the position-independent
    /// Markov chain.
    ///
    /// # Panics
    /// Panics if no trap is defined for the current state.
    pub fn into_state_iter_static(self) -> AtomIterStatic<S, T, R> {
        return Self::state_iter_static(
            self.state,
            self.state_graph,
            self.mass,
            self.temperature,
        ).unwrap_or_else(|err| panic!("{}", err));
    }

    /// Make a transition to a new state and return the corresponding photon
    /// interaction, if possible. On success `self.state` is updated; on
    /// failure it is left unchanged.
    ///
    /// # Errors
    /// Propagates any error from the state graph's `next_state_checked_rng`.
    pub fn next_state_checked_rng<G>(&mut self, q: PhaseSpace, rng: &mut G)
        -> AtomResult<(S, PhotonInteraction)>
    where G: Rng + ?Sized
    {
        let (state, photon): (S, PhotonInteraction)
            = self.state_graph.next_state_checked_rng(&self.state, q, rng)?;
        self.state = state;
        return Ok((state, photon));
    }

    /// Make a transition to a new state and return the corresponding photon
    /// interaction, if possible, using the thread-local generator.
    ///
    /// # Errors
    /// See [`Atom::next_state_checked_rng`].
    pub fn next_state_checked(&mut self, q: PhaseSpace)
        -> AtomResult<(S, PhotonInteraction)>
    {
        let mut rng = rnd::thread_rng();
        return self.next_state_checked_rng(q, &mut rng);
    }

    /// Make a transition to a new state and return the corresponding photon
    /// interaction, if possible. If a transition did not occur, the photon
    /// interaction will be `None`.
    pub fn next_state_rng<G>(&mut self, q: PhaseSpace, rng: &mut G)
        -> (S, Option<PhotonInteraction>)
    where G: Rng + ?Sized
    {
        let (state, maybe_photon): (S, Option<PhotonInteraction>)
            = self.state_graph.next_state_rng(&self.state, q, rng);
        self.state = state;
        return (state, maybe_photon);
    }

    /// Make a transition to a new state and return the corresponding photon
    /// interaction, if possible, using the thread-local generator. If a
    /// transition did not occur, the photon interaction will be `None`.
    pub fn next_state(&mut self, q: PhaseSpace)
        -> (S, Option<PhotonInteraction>)
    {
        let mut rng = rnd::thread_rng();
        return self.next_state_rng(q, &mut rng);
    }

    /// Make a transition to a new state and return the corresponding photon
    /// interaction, if possible, disregarding position and momentum. On
    /// success `self.state` is updated; on failure it is left unchanged.
    ///
    /// # Errors
    /// Propagates any error from the state graph's
    /// `next_state_static_checked_rng`.
    pub fn next_state_static_checked_rng<G>(&mut self, rng: &mut G)
        -> AtomResult<(S, PhotonInteraction)>
    where G: Rng + ?Sized
    {
        let (state, photon): (S, PhotonInteraction)
            = self.state_graph.next_state_static_checked_rng(&self.state, rng)?;
        self.state = state;
        return Ok((state, photon));
    }

    /// Make a transition to a new state and return the corresponding photon
    /// interaction, if possible, disregarding position and momentum, using
    /// the thread-local generator.
    ///
    /// # Errors
    /// See [`Atom::next_state_static_checked_rng`].
    pub fn next_state_static_checked(&mut self)
        -> AtomResult<(S, PhotonInteraction)>
    {
        let mut rng = rnd::thread_rng();
        return self.next_state_static_checked_rng(&mut rng);
    }

    /// Make a transition to a new state and return the corresponding photon
    /// interaction, if possible, disregarding position and momentum. If a
    /// transition did not occur, the photon interaction will be `None`.
    pub fn next_state_static_rng<G>(&mut self, rng: &mut G)
        -> (S, Option<PhotonInteraction>)
    where G: Rng + ?Sized
    {
        let (state, maybe_photon): (S, Option<PhotonInteraction>)
            = self.state_graph.next_state_static_rng(&self.state, rng);
        self.state = state;
        return (state, maybe_photon);
    }

    /// Make a transition to a new state and return the corresponding photon
    /// interaction, if possible, disregarding position and momentum, using
    /// the thread-local generator. If a transition did not occur, the photon
    /// interaction will be `None`.
    pub fn next_state_static(&mut self)
        -> (S, Option<PhotonInteraction>)
    {
        let mut rng = rnd::thread_rng();
        return self.next_state_static_rng(&mut rng);
    }
}

/// `Iterator` adaptor that drives an [`Atom`]'s Markov chain.
///
/// [`AtomIter::dump`] hands back the underlying [`Atom`] without cloning it.
#[derive(Clone)]
pub struct AtomIter<S: State, T: Trap, R: RadiationPattern> {
    atom: Atom<S, T, R>,
    /// Current phase-space vector of the atom. Every transition sampled by
    /// the iterator uses this value; change it with [`AtomIter::set_q`] or by
    /// direct assignment.
    pub q: PhaseSpace,
    rng: rnd::ThreadRng,
}

impl<S: State, T: Trap, R: RadiationPattern> AtomIter<S, T, R> {
    /// Assemble an `AtomIter` from its parts. Prefer `Atom::state_iter` or a
    /// related constructor, which also sample an initial `q`.
    pub fn new(atom: Atom<S, T, R>, q: PhaseSpace, rng: rnd::ThreadRng)
        -> Self
    {
        Self { atom, q, rng }
    }

    /// Borrow the atom in its current state.
    pub fn get_atom(&self) -> &Atom<S, T, R> {
        &self.atom
    }

    /// Trapping potential for the atom's current state (delegates to
    /// [`Atom::get_trap`]).
    pub fn get_trap(&self) -> &T {
        self.atom.get_trap()
    }

    /// Trapping potential for an arbitrary `state`, if one is defined.
    pub fn get_trap_for(&self, state: &S) -> Option<&T> {
        self.atom.get_trap_for(state)
    }

    /// Replace the phase-space vector used for subsequent transitions.
    pub fn set_q(&mut self, q: PhaseSpace) {
        self.q = q;
    }

    /// Consume the iterator, returning the underlying [`Atom`] together with
    /// its final phase-space vector.
    pub fn dump(self) -> (Atom<S, T, R>, PhaseSpace) {
        let Self { atom, q, .. } = self;
        (atom, q)
    }
}

impl<S: State, T: Trap, R: RadiationPattern> Iterator for AtomIter<S, T, R> {
    type Item = (S, PhotonInteraction);

    /// Perform one transition at the current `q`, yielding the new state and
    /// photon interaction, or `None` when no transition occurs.
    fn next(&mut self) -> Option<Self::Item> {
        let q = self.q;
        let (state, photon) = self.atom.next_state_rng(q, &mut self.rng);
        photon.map(|interaction| (state, interaction))
    }
}

/// `Iterator` adaptor that drives an [`Atom`]'s position-independent Markov
/// chain.
///
/// [`AtomIterStatic::dump`] hands back the underlying [`Atom`] without cloning
/// it.
#[derive(Clone)]
pub struct AtomIterStatic<S: State, T: Trap, R: RadiationPattern> {
    atom: Atom<S, T, R>,
    rng: rnd::ThreadRng,
}

impl<S: State, T: Trap, R: RadiationPattern> AtomIterStatic<S, T, R> {
    /// Assemble an `AtomIterStatic` from its parts. Prefer
    /// `Atom::state_iter_static` or a related constructor.
    pub fn new(atom: Atom<S, T, R>, rng: rnd::ThreadRng) -> Self {
        Self { atom, rng }
    }

    /// Borrow the atom in its current state.
    pub fn get_atom(&self) -> &Atom<S, T, R> {
        &self.atom
    }

    /// Trapping potential for the atom's current state (delegates to
    /// [`Atom::get_trap`]).
    pub fn get_trap(&self) -> &T {
        self.atom.get_trap()
    }

    /// Trapping potential for an arbitrary `state`, if one is defined.
    pub fn get_trap_for(&self, state: &S) -> Option<&T> {
        self.atom.get_trap_for(state)
    }

    /// Consume the iterator, returning the underlying [`Atom`].
    pub fn dump(self) -> Atom<S, T, R> {
        let Self { atom, .. } = self;
        atom
    }
}

impl<S: State, T: Trap, R: RadiationPattern> Iterator for AtomIterStatic<S, T, R> {
    type Item = (S, PhotonInteraction);

    /// Perform one position-independent transition, yielding the new state
    /// and photon interaction, or `None` when no transition occurs.
    fn next(&mut self) -> Option<Self::Item> {
        let (state, photon) = self.atom.next_state_static_rng(&mut self.rng);
        photon.map(|interaction| (state, interaction))
    }
}