from __future__ import annotations
import numpy as np
from sympy.physics.quantum.cg import CG
from sympy.physics.wigner import wigner_3j
from itertools import product
from enum import Enum

"""
Functions to generate Hamiltonians for the 171Yb clock-Rydberg manifold:

                                        r1+ ======
                            r0+ ======
                r0- ======
    r1- ======


                g0  ======
                            g1  ======

Provides additional functionality for time-dependence as well as couplings for
sigma+, sigma-, and pi polarized light with adjustable polarization impurity.
Implicitly applies the rotating wave approximation. All energies are to be
taken in s^-1.
"""

class Basis:
    """
    Ordered collection of labeled basis states with an energy per state.

    Each label is associated with a unit vector (`numpy.complex128`) whose
    single nonzero entry sits at the label's position in insertion order.

    Parameters
    ----------
    states : dict[str, float]
        Mapping from state label to energy. Insertion order fixes the state
        ordering. The mapping is stored by reference in `energies`.

    Attributes
    ----------
    states : dict[str, numpy.ndarray]
        Unit vector for each label.
    energies : dict[str, float]
        The mapping passed to the constructor.
    labels : list[str]
        Labels in order.
    indices : dict[str, int]
        Position of each label in `labels`.
    """

    def __init__(self, states: dict[str, float]):
        self.energies = states
        self.labels = list(states.keys())
        self.indices = {l: i for i, l in enumerate(self.labels)}
        # Rows of the identity are exactly the unit vectors; copying each row
        # keeps the stored vectors independent of one another.
        identity = np.eye(len(self.labels), dtype=np.complex128)
        self.states = {
            label: identity[k].copy() for k, label in enumerate(self.labels)
        }

    def __getitem__(self, key: str | int):
        """
        Return the label at position `key` (integer) or the state vector for
        label `key` (string). Raises `KeyError` for any other key type.
        """
        # NumPy integers (e.g. from `numpy.argmax`) are not `int` instances,
        # so they are accepted explicitly.
        if isinstance(key, (int, np.integer)):
            return self.labels[key]
        if isinstance(key, str):
            return self.states[key]
        raise KeyError(key)

    def __iter__(self):
        """Iterate over `(label, state vector, energy)` triples."""
        return iter(self.items())

    def items(self):
        """Return a list of `(label, state vector, energy)` triples."""
        return [(l, self.states[l], self.energies[l]) for l in self.labels]

    def keys(self):
        """Return the list of labels."""
        return self.labels

    def values(self):
        """Return a list of `(state vector, energy)` pairs."""
        return [(self.states[l], self.energies[l]) for l in self.labels]

    def __len__(self):
        return len(self.labels)

    def label(self, index: int):
        """Return the label at position `index`."""
        return self.labels[index]

    def index(self, label: str):
        """Return the position of `label`."""
        return self.indices[label]

    def energy(self, k: str | int):
        """
        Return the energy of the state given by label (string) or position
        (integer). Raises `KeyError` for any other key type.
        """
        if isinstance(k, str):
            return self.energies[k]
        if isinstance(k, (int, np.integer)):
            return self.energies[self.labels[k]]
        raise KeyError(k)

    def to_multiatom(self, n: int, delim: str=","):
        """
        Build the `n`-atom product basis.

        Labels are the single-atom labels joined by `delim`, ordered as
        `itertools.product` yields them; each energy is the sum of the
        constituent single-atom energies.
        """
        return Basis({
            delim.join(combo): sum(self.energies[s] for s in combo)
            for combo in product(self.labels, repeat=n)
        })

def trans_w3j(F0, mF0, F1, mF1):
    """
    Magnitude of the angular factor for the transition (F0, mF0) -> (F1, mF1).

    Evaluates

        | (-1)^(F1 - 1 + mF0) * sqrt(2 F0 + 1)
          * wigner_3j(F1, 1, F0, mF1, mF0 - mF1, -mF0) |

    and returns it as a Python float. The sign factor does not affect the
    result because the absolute value is taken.
    """
    sign = (-1)**(F1 - 1 + mF0)
    three_j = wigner_3j(F1, 1, F0, mF1, mF0 - mF1, -mF0).doit().evalf()
    return float(abs(sign * np.sqrt(2 * F0 + 1) * three_j))
# Angular weights for every ground <-> Rydberg transition, keyed by the label
# pair in both orders. Each row is (ground label, ground mF, Rydberg label,
# Rydberg mF); the ground states use F = 1/2 and the Rydberg states F = 3/2.
w3j = {
    pair: trans_w3j(1/2, mF_ground, 3/2, mF_rydberg)
    for ground, mF_ground, rydberg, mF_rydberg in (
        ("g1", +1/2, "r1+", +3/2),
        ("g1", +1/2, "r0+", +1/2),
        ("g1", +1/2, "r0-", -1/2),
        ("g0", -1/2, "r0+", +1/2),
        ("g0", -1/2, "r0-", -1/2),
        ("g0", -1/2, "r1-", -3/2),
    )
    for pair in ((ground, rydberg), (rydberg, ground))
}

class PulseType(Enum):
    """
    Polarization of the principal drive. `H_pulse` uses this value to choose
    which Hamiltonian constructor to call.
    """
    Sigma_p = 0  # sigma+ drive; `H_pulse` dispatches to `H_sigma_p`
    Sigma_m = 1  # sigma- drive; `H_pulse` dispatches to `H_sigma_m`
    Pi = 2  # pi-polarized drive; `H_pulse` dispatches to `H_pi`

def trapz_progressive(y, dx=1.0):
    """
    Running trapezoidal integral of `y` along its first axis, computed with an
    explicit loop.

    Parameters
    ----------
    y : numpy.ndarray
        Samples to integrate.
    dx : float (optional)
        Spacing between consecutive samples. Defaults to 1.

    Returns
    -------
    I : numpy.ndarray
        Array with the shape and dtype of `y`; entry 0 is zero and entry k
        holds the integral accumulated from sample 0 through sample k.
    """
    integral = np.zeros(y.shape, dtype=y.dtype)
    running = 0.0
    # Pair each sample with its predecessor: (y[1], y[0]), (y[2], y[1]), ...
    for k, (current, previous) in enumerate(zip(y[1:], y), start=1):
        running += dx * (current + previous) / 2.0
        integral[k] = running
    return integral

def trapz_prog(y, dx=1.0):
    """
    Running trapezoidal integral of `y`, vectorized.

    Parameters
    ----------
    y : numpy.ndarray
        Samples to integrate.
    dx : float (optional)
        Spacing between consecutive samples. Defaults to 1.

    Returns
    -------
    I : numpy.ndarray
        1D array starting with zero, followed by the cumulative sum of the
        trapezoid areas between consecutive samples.
    """
    panel_areas = dx * (y[:-1] + y[1:]) / 2.0
    accumulated = np.cumsum(panel_areas)
    return np.concatenate(([0], accumulated))

def _gen_time_dep(t, integrate: list[float | np.ndarray],
        extrude: list[float]) -> list[np.ndarray]:
    assert all(np.shape(X) in {t.shape, ()} for X in integrate)
    dt = abs(t[1] - t[0])
    ones = np.ones(t.shape)
    return [
        trapz_prog(X, dt) if len(np.shape(X)) > 0 else X * t
        for X in integrate
    ] + [X * ones for X in extrude]

def H_pulse(pulse_type: PulseType, basis, w, W, chi, t, phi=0) \
        -> np.ndarray:
    """
    Thin wrapper around `H_sigma_p`, `H_sigma_m`, or `H_pi`, depending on the
    value of `pulse_type`.

    Parameters
    ----------
    pulse_type : PulseType
        Polarization of the principal drive.
    basis, w, W, chi, t, phi
        Forwarded unchanged to the selected constructor.

    Returns
    -------
    H : numpy.ndarray
        Whatever the selected constructor returns.

    Raises
    ------
    ValueError
        If `pulse_type` is not a member of `PulseType`.
    """
    if pulse_type == PulseType.Sigma_p:
        return H_sigma_p(basis, w, W, chi, t, phi)
    # `H_sigma_m` and `H_pi` take two extra positional parameters (`d`, `D`)
    # between `w` and `W` that their bodies never read. Placeholders keep the
    # remaining arguments aligned with those signatures.
    if pulse_type == PulseType.Sigma_m:
        return H_sigma_m(basis, w, None, None, W, chi, t, phi)
    if pulse_type == PulseType.Pi:
        return H_pi(basis, w, None, None, W, chi, t, phi)
    raise ValueError(f"unsupported pulse type: {pulse_type!r}")

def H_sigma_p(basis, w, W, chi, t, phi=0) -> np.ndarray:
    """
    Construct the Hamiltonian with only off-diagonal elements for a sigma+
    drive. Requires time dependence.

    The principal transitions are g1 <-> r1+ and g0 <-> r0+, weighted by
    sqrt(1 - chi). The remaining four ground <-> Rydberg transitions are
    weighted by sqrt(chi / 2). Every transition except g1 <-> r1+ is further
    scaled by its `w3j` weight relative to that of g1 <-> r1+.

    Parameters
    ----------
    basis : Basis
        Single-atom basis. Must provide energies for the labels "g0", "g1",
        "r1-", "r0-", "r0+", and "r1+".
    w : float | numpy.ndarray
        Drive frequency relative to the difference between the averages of the
        ground and Rydberg manifolds. Integrated over `t` to give the drive
        phase.
    W : float | numpy.ndarray
        Rabi frequency of the principal drive. Defined to be the frequency of
        oscillation in populations, not amplitudes.
    chi : float | numpy.ndarray
        Polarization impurity.
    t : numpy.ndarray
        Time grid, passed to `_gen_time_dep`.
    phi : float (optional)
        Phase shift on the drive. Defaults to zero.

    Returns
    -------
    H : numpy.ndarray[dim=3, dtype=numpy.complex128]
        Array of shape (len(basis), len(basis), len(t)) with only off-diagonal
        elements for sigma+ light; the last axis is time.
    """
    level_names = ["g0", "g1", "r1-", "r0-", "r0+", "r1+"]
    traces = _gen_time_dep(
        t,
        integrate=[w] + [basis.energy(name) for name in level_names],
        extrude=[W, chi],
    )
    drive_phase = traces[0]
    level_phase = dict(zip(level_names, traces[1:7]))
    rabi, impurity = traces[7], traces[8]

    reference = ("g1", "r1+")
    principal = np.sqrt(1 - impurity)
    leakage = np.sqrt(impurity / 2)
    # (ground, Rydberg, polarization weight, rescale by w3j ratio?)
    couplings = [
        ("g1", "r1+", principal, False),
        ("g1", "r0+", leakage, True),
        ("g1", "r0-", leakage, True),
        ("g0", "r0+", principal, True),
        ("g0", "r0-", leakage, True),
        ("g0", "r1-", leakage, True),
    ]

    position = {name: k for k, name in enumerate(basis.labels)}
    dim = len(basis)
    H = np.zeros((dim, dim, t.shape[0]), dtype=np.complex128)
    for ground, rydberg, weight, rescale in couplings:
        if ground not in position or rydberg not in position:
            continue
        amplitude = (rabi / 2) * weight
        if rescale:
            amplitude = amplitude * w3j[(ground, rydberg)] / w3j[reference]
        # Drive phase minus the accumulated transition phase, plus the offset.
        transition_phase = level_phase[rydberg] - level_phase[ground]
        rotation = drive_phase - transition_phase + phi
        row, col = position[ground], position[rydberg]
        H[row, col, :] = amplitude * np.exp(+1j * rotation)
        H[col, row, :] = amplitude * np.exp(-1j * rotation)
    return H

def H_sigma_m(basis, w, d, D, W, chi, t, phi=0) -> np.ndarray:
    """
    Construct the Hamiltonian with only off-diagonal elements for a sigma-
    drive. Requires time dependence.

    The principal transitions are g1 <-> r0- and g0 <-> r1-, weighted by
    sqrt(1 - chi). The remaining four ground <-> Rydberg transitions are
    weighted by sqrt(chi / 2). Every transition except g1 <-> r0- is further
    scaled by its `w3j` weight relative to that of g1 <-> r0-.

    Parameters
    ----------
    basis : Basis
        Single-atom basis. Must provide energies for the labels "g0", "g1",
        "r1-", "r0-", "r0+", and "r1+".
    w : float | numpy.ndarray
        Drive frequency relative to the difference between the averages of the
        ground and Rydberg manifolds. Integrated over `t` to give the drive
        phase.
    d : any
        Ground state splitting. Not read by this function; level energies are
        taken from `basis`.
    D : any
        Rydberg state splitting. Not read by this function; level energies are
        taken from `basis`.
    W : float | numpy.ndarray
        Rabi frequency of the principal drive. Defined to be the frequency of
        oscillation in populations, not amplitudes.
    chi : float | numpy.ndarray
        Polarization impurity.
    t : numpy.ndarray
        Time grid, passed to `_gen_time_dep`.
    phi : float (optional)
        Phase shift on the drive. Defaults to zero.

    Returns
    -------
    H : numpy.ndarray[dim=3, dtype=numpy.complex128]
        Array of shape (len(basis), len(basis), len(t)) with only off-diagonal
        elements for sigma- light; the last axis is time.
    """
    level_names = ["g0", "g1", "r1-", "r0-", "r0+", "r1+"]
    traces = _gen_time_dep(
        t,
        integrate=[w] + [basis.energy(name) for name in level_names],
        extrude=[W, chi],
    )
    drive_phase = traces[0]
    level_phase = dict(zip(level_names, traces[1:7]))
    rabi, impurity = traces[7], traces[8]

    reference = ("g1", "r0-")
    principal = np.sqrt(1 - impurity)
    leakage = np.sqrt(impurity / 2)
    # (ground, Rydberg, polarization weight, rescale by w3j ratio?)
    couplings = [
        ("g1", "r0-", principal, False),
        ("g1", "r0+", leakage, True),
        ("g1", "r1+", leakage, True),
        ("g0", "r1-", principal, True),
        ("g0", "r0-", leakage, True),
        ("g0", "r0+", leakage, True),
    ]

    position = {name: k for k, name in enumerate(basis.labels)}
    dim = len(basis)
    H = np.zeros((dim, dim, t.shape[0]), dtype=np.complex128)
    for ground, rydberg, weight, rescale in couplings:
        if ground not in position or rydberg not in position:
            continue
        amplitude = (rabi / 2) * weight
        if rescale:
            amplitude = amplitude * w3j[(ground, rydberg)] / w3j[reference]
        # Drive phase minus the accumulated transition phase, plus the offset.
        transition_phase = level_phase[rydberg] - level_phase[ground]
        rotation = drive_phase - transition_phase + phi
        row, col = position[ground], position[rydberg]
        H[row, col, :] = amplitude * np.exp(+1j * rotation)
        H[col, row, :] = amplitude * np.exp(-1j * rotation)
    return H

def H_pi(basis, w, d, D, W, chi, t, phi=0) -> np.ndarray:
    """
    Construct the Hamiltonian with only off-diagonal elements for a pi-polarized
    drive. Requires time dependence.

    The principal transitions are g1 <-> r0+ and g0 <-> r0-, weighted by
    sqrt(1 - chi). The remaining four ground <-> Rydberg transitions are
    weighted by sqrt(chi / 2). Every transition except g1 <-> r0+ is further
    scaled by its `w3j` weight relative to that of g1 <-> r0+.

    Parameters
    ----------
    basis : Basis
        Single-atom basis. Must provide energies for the labels "g0", "g1",
        "r1-", "r0-", "r0+", and "r1+".
    w : float | numpy.ndarray
        Drive frequency relative to the difference between the average
        frequencies of the ground and Rydberg manifolds. Integrated over `t`
        to give the drive phase.
    d : any
        Ground state splitting. Not read by this function; level energies are
        taken from `basis`.
    D : any
        Rydberg state splitting. Not read by this function; level energies are
        taken from `basis`.
    W : float | numpy.ndarray
        Rabi frequency of the principal drive. Defined to be the frequency of
        oscillation in populations, not amplitudes.
    chi : float | numpy.ndarray
        Polarization impurity.
    t : numpy.ndarray
        Time grid, passed to `_gen_time_dep`.
    phi : float (optional)
        Phase shift on the drive. Defaults to zero.

    Returns
    -------
    H : numpy.ndarray[dim=3, dtype=numpy.complex128]
        Array of shape (len(basis), len(basis), len(t)) with only off-diagonal
        elements for pi-polarized light; the last axis is time.
    """
    level_names = ["g0", "g1", "r1-", "r0-", "r0+", "r1+"]
    traces = _gen_time_dep(
        t,
        integrate=[w] + [basis.energy(name) for name in level_names],
        extrude=[W, chi],
    )
    drive_phase = traces[0]
    level_phase = dict(zip(level_names, traces[1:7]))
    rabi, impurity = traces[7], traces[8]

    reference = ("g1", "r0+")
    principal = np.sqrt(1 - impurity)
    leakage = np.sqrt(impurity / 2)
    # (ground, Rydberg, polarization weight, rescale by w3j ratio?)
    couplings = [
        ("g1", "r0+", principal, False),
        ("g1", "r1+", leakage, True),
        ("g1", "r0-", leakage, True),
        ("g0", "r0-", principal, True),
        ("g0", "r0+", leakage, True),
        ("g0", "r1-", leakage, True),
    ]

    position = {name: k for k, name in enumerate(basis.labels)}
    dim = len(basis)
    H = np.zeros((dim, dim, t.shape[0]), dtype=np.complex128)
    for ground, rydberg, weight, rescale in couplings:
        if ground not in position or rydberg not in position:
            continue
        amplitude = (rabi / 2) * weight
        if rescale:
            amplitude = amplitude * w3j[(ground, rydberg)] / w3j[reference]
        # Drive phase minus the accumulated transition phase, plus the offset.
        transition_phase = level_phase[rydberg] - level_phase[ground]
        rotation = drive_phase - transition_phase + phi
        row, col = position[ground], position[rydberg]
        H[row, col, :] = amplitude * np.exp(+1j * rotation)
        H[col, row, :] = amplitude * np.exp(-1j * rotation)
    return H

def gen_multiatom(basis, n, H, U, basis_n=None) \
        -> "tuple[Basis, np.ndarray]":
    """
    Generate the n-atom Hamiltonian from a single-atom Hamiltonian (assume all
    atoms are subject to the same drive).

    Each n-atom basis state whose label contains k > 1 occurrences of the
    character "r" has (k - 1) * U added to its diagonal element.

    Parameters
    ----------
    basis : Basis
        Single-atom basis.
    n : int
        Number of atoms.
    H : numpy.ndarray[shape=(_, _) or (_, _, k), dtype=numpy.complex128]
        Single-atom Hamiltonian. Can be time-dependent, with the third index
        corresponding to time, which requires `U` be time-dependent as well.
    U : scalar | numpy.ndarray[shape=(k,)]
        Rydberg interaction strength. Any zero-dimensional value (Python or
        NumPy real/complex scalar) is treated as time-independent. If
        time-dependent (passed as an array), then `H` is required to be
        time-dependent as well.
    basis_n : Basis (optional)
        Precomputed n-atom basis. Defaults to `basis.to_multiatom(n)`.

    Returns
    -------
    B_n : Basis
        n-atom basis.
    H_n : numpy.ndarray[shape=(_, _) or (_, _, k), dtype=numpy.complex128]
        n-atom Hamiltonian.
    """
    # Classify by dimensionality, not by `isinstance(U, float)`: the elements
    # of an integer or complex `U` array are scalars but not `float`
    # instances, and they must be accepted by the per-time-step recursion.
    scalar_U = np.ndim(U) == 0
    time_dep = len(H.shape) == 3 and not scalar_U
    time_indep = len(H.shape) == 2 and scalar_U
    assert time_dep ^ time_indep, \
        "H and U must both be time-dependent or both be time-independent"
    B_n = basis.to_multiatom(n) if basis_n is None else basis_n

    if time_dep:
        U = np.asarray(U)
        assert H.shape[-1] == U.shape[0]
        # Build one Hamiltonian per time step; the paired transposes move the
        # time axis from first to last without swapping rows and columns.
        H_n = np.array([
            gen_multiatom(basis, n, H[:, :, step], U[step], B_n)[1].T
            for step in range(U.shape[0])
        ], dtype=np.complex128).T
        return B_n, H_n

    dim = len(basis)
    # Sum of I^(k) (x) H (x) I^(n-k-1): H acting on atom k, identity elsewhere.
    H_n = sum(
        np.kron(
            np.kron(np.eye(dim**k, dtype=np.complex128), H),
            np.eye(dim**(n - k - 1), dtype=np.complex128)
        )
        for k in range(n)
    )
    for row, label in enumerate(B_n.labels):
        n_rydberg = label.count("r")
        if n_rydberg > 1:
            H_n[row, row] += (n_rydberg - 1) * U
    return B_n, H_n

def gen_chain(basis, n, H, U, basis_n=None) \
        -> "tuple[Basis, np.ndarray]":
    """
    Generate the n-atom Hamiltonian for a 1D Rydberg chain from a single-atom
    Hamiltonian (assume all atoms are subject to the same drive).

    Atoms sit at integer positions 0 ... n - 1. For every n-atom basis state,
    each pair of atoms whose single-atom labels both contain "r" adds
    U / (distance)^6 to the state's diagonal element, where distance is the
    difference in their positions.

    Parameters
    ----------
    basis : Basis
        Single-atom basis.
    n : int
        Number of atoms.
    H : numpy.ndarray[shape=(_, _) or (_, _, k), dtype=numpy.complex128]
        Single-atom Hamiltonian. Can be time-dependent, with the third index
        corresponding to time, which requires `U` be time-dependent as well.
    U : scalar | numpy.ndarray[shape=(k,)]
        Rydberg interaction strength between neighboring atoms. Any
        zero-dimensional value is treated as time-independent. If
        time-dependent (passed as an array), then `H` is required to be
        time-dependent as well.
    basis_n : Basis (optional)
        Precomputed n-atom basis. Defaults to `basis.to_multiatom(n)`. Its
        labels must join single-atom labels with ",".

    Returns
    -------
    B_n : Basis
        n-atom basis.
    H_n : numpy.ndarray[shape=(_, _) or (_, _, k), dtype=numpy.complex128]
        n-atom Hamiltonian.
    """
    # Classify by dimensionality, not by `isinstance(U, float)`: the elements
    # of an integer or complex `U` array are scalars but not `float`
    # instances, and they must be accepted by the per-time-step recursion.
    scalar_U = np.ndim(U) == 0
    time_dep = len(H.shape) == 3 and not scalar_U
    time_indep = len(H.shape) == 2 and scalar_U
    assert time_dep ^ time_indep, \
        "H and U must both be time-dependent or both be time-independent"
    B_n = basis.to_multiatom(n) if basis_n is None else basis_n

    if time_dep:
        U = np.asarray(U)
        assert H.shape[-1] == U.shape[0]
        # Build one Hamiltonian per time step; the paired transposes move the
        # time axis from first to last without swapping rows and columns.
        H_n = np.array([
            gen_chain(basis, n, H[:, :, step], U[step], B_n)[1].T
            for step in range(U.shape[0])
        ], dtype=np.complex128).T
        return B_n, H_n

    dim = len(basis)
    # Sum of I^(k) (x) H (x) I^(n-k-1): H acting on atom k, identity elsewhere.
    H_n = sum(
        np.kron(
            np.kron(np.eye(dim**k, dtype=np.complex128), H),
            np.eye(dim**(n - k - 1), dtype=np.complex128)
        )
        for k in range(n)
    )
    # Distance-dependent interaction. This is the only diagonal shift applied:
    # a distance-independent (k - 1) * U shift would ignore the chain geometry.
    for row, label in enumerate(B_n.labels):
        sites = [
            pos for pos, atom in enumerate(label.split(",")) if "r" in atom
        ]
        if len(sites) < 2:
            continue
        # Each unordered pair of Rydberg atoms is counted once.
        H_n[row, row] += sum(
            U / (right - left)**6
            for a, left in enumerate(sites)
            for right in sites[a + 1:]
        )
    return B_n, H_n

def gen_multiatom_per(basis, HH, U, basis_n=None):
    """
    Generate the n-atom Hamiltonian from a collection of single-atom
    Hamiltonians, one per atom.

    Each n-atom basis state whose label contains k > 1 occurrences of the
    character "r" has (k - 1) * U added to its diagonal element.

    Parameters
    ----------
    basis : Basis
        Single-atom basis.
    HH : list[numpy.ndarray[shape=(_, _) or (_, _, k), dtype=numpy.complex128]]
        List-like of single-atom Hamiltonians; its length sets the number of
        atoms. Can be time-dependent, with the third index corresponding to
        time, which requires `U` be time-dependent as well.
    U : scalar | numpy.ndarray[shape=(k,)]
        Rydberg interaction strength. Any zero-dimensional value (Python or
        NumPy real/complex scalar) is treated as time-independent. If
        time-dependent (passed as an array), then every entry of `HH` is
        required to be time-dependent as well.
    basis_n : Basis (optional)
        Precomputed n-atom basis. Defaults to `basis.to_multiatom(n)`.

    Returns
    -------
    B_n : Basis
        n-atom basis.
    H_n : numpy.ndarray[shape=(_, _) or (_, _, k), dtype=numpy.complex128]
        n-atom Hamiltonian.
    """
    # Classify by dimensionality, not by `isinstance(U, float)`: the elements
    # of an integer or complex `U` array are scalars but not `float`
    # instances, and they must be accepted by the per-time-step recursion.
    scalar_U = np.ndim(U) == 0
    time_dep = all(len(H.shape) == 3 for H in HH) and not scalar_U
    time_indep = all(len(H.shape) == 2 for H in HH) and scalar_U
    assert time_dep ^ time_indep, \
        "HH and U must both be time-dependent or both be time-independent"
    n = len(HH)
    B_n = basis.to_multiatom(n) if basis_n is None else basis_n

    if time_dep:
        U = np.asarray(U)
        assert all(H.shape[-1] == U.shape[0] for H in HH)
        # Build one Hamiltonian per time step; the paired transposes move the
        # time axis from first to last without swapping rows and columns.
        H_n = np.array([
            gen_multiatom_per(
                basis, [H[:, :, step] for H in HH], U[step], B_n)[1].T
            for step in range(U.shape[0])
        ], dtype=np.complex128).T
        return B_n, H_n

    dim = len(basis)
    # Sum of I^(k) (x) H_k (x) I^(n-k-1): H_k acts on atom k only.
    H_n = sum(
        np.kron(
            np.kron(np.eye(dim**k, dtype=np.complex128), H),
            np.eye(dim**(n - k - 1), dtype=np.complex128)
        )
        for k, H in enumerate(HH)
    )
    for row, label in enumerate(B_n.labels):
        n_rydberg = label.count("r")
        if n_rydberg > 1:
            H_n[row, row] += (n_rydberg - 1) * U
    return B_n, H_n

def gen_chain_per(basis, HH, U, basis_n=None):
    """
    Generate the n-atom Hamiltonian for a 1D Rydberg chain from a collection of
    single-atom Hamiltonians, one per atom.

    Atoms sit at integer positions 0 ... n - 1. For every n-atom basis state,
    each pair of atoms whose single-atom labels both contain "r" adds
    U / (distance)^6 to the state's diagonal element, where distance is the
    difference in their positions.

    Parameters
    ----------
    basis : Basis
        Single-atom basis.
    HH : list[numpy.ndarray[shape=(_, _) or (_, _, k), dtype=numpy.complex128]]
        List-like of single-atom Hamiltonians; its length sets the number of
        atoms. Can be time-dependent, with the third index corresponding to
        time, which requires `U` be time-dependent as well.
    U : scalar | numpy.ndarray[shape=(k,)]
        Rydberg interaction strength between neighboring atoms. Any
        zero-dimensional value is treated as time-independent. If
        time-dependent (passed as an array), then every entry of `HH` is
        required to be time-dependent as well.
    basis_n : Basis (optional)
        Precomputed n-atom basis. Defaults to `basis.to_multiatom(n)`. Its
        labels must join single-atom labels with ",".

    Returns
    -------
    B_n : Basis
        n-atom basis.
    H_n : numpy.ndarray[shape=(_, _) or (_, _, k), dtype=numpy.complex128]
        n-atom Hamiltonian.
    """
    # Classify by dimensionality, not by `isinstance(U, float)`: the elements
    # of an integer or complex `U` array are scalars but not `float`
    # instances, and they must be accepted by the per-time-step recursion.
    scalar_U = np.ndim(U) == 0
    time_dep = all(len(H.shape) == 3 for H in HH) and not scalar_U
    time_indep = all(len(H.shape) == 2 for H in HH) and scalar_U
    assert time_dep ^ time_indep, \
        "HH and U must both be time-dependent or both be time-independent"
    n = len(HH)
    B_n = basis.to_multiatom(n) if basis_n is None else basis_n

    if time_dep:
        U = np.asarray(U)
        assert all(H.shape[-1] == U.shape[0] for H in HH)
        # Build one Hamiltonian per time step; the paired transposes move the
        # time axis from first to last without swapping rows and columns.
        H_n = np.array([
            gen_chain_per(
                basis, [H[:, :, step] for H in HH], U[step], B_n)[1].T
            for step in range(U.shape[0])
        ], dtype=np.complex128).T
        return B_n, H_n

    dim = len(basis)
    # Sum of I^(k) (x) H_k (x) I^(n-k-1): H_k acts on atom k only.
    H_n = sum(
        np.kron(
            np.kron(np.eye(dim**k, dtype=np.complex128), H),
            np.eye(dim**(n - k - 1), dtype=np.complex128)
        )
        for k, H in enumerate(HH)
    )
    # Distance-dependent interaction. This is the only diagonal shift applied:
    # a distance-independent (k - 1) * U shift would ignore the chain geometry.
    for row, label in enumerate(B_n.labels):
        sites = [
            pos for pos, atom in enumerate(label.split(",")) if "r" in atom
        ]
        if len(sites) < 2:
            continue
        # Each unordered pair of Rydberg atoms is counted once.
        H_n[row, row] += sum(
            U / (right - left)**6
            for a, left in enumerate(sites)
            for right in sites[a + 1:]
        )
    return B_n, H_n

def enumerators(basis, n=1, enumeration=None):
    """
    Build one diagonal "enumeration" operator per atom. The operator for atom
    k assigns a number to each single-atom state of `basis` and acts as the
    identity on every other atom of the `n`-atom basis.

    Parameters
    ----------
    basis : Basis
        Single-atom basis.
    n : int (optional)
        Number of atoms. Defaults to 1.
    enumeration : list[int] (optional)
        Numbers to assign to the single-atom states, in basis order. Must have
        the same length as `basis`. If left unspecified, use
        [0, ..., B - 1], where B is the length of `basis`.

    Returns
    -------
    ops : list[numpy.ndarray[dim=2, dtype=numpy.float64]]
        Enumeration operators for single atoms, written in the `n`-atom basis.
    """
    assert enumeration is None or len(enumeration) == len(basis)
    dim = len(basis)
    if enumeration is None:
        numbering = np.arange(dim, dtype=np.float64)
    else:
        numbering = enumeration
    single_atom = np.diag(numbering)
    # I^(site) (x) E (x) I^(n-site-1): E acts on atom `site` only.
    return [
        np.kron(
            np.kron(np.eye(dim**site, dtype=np.float64), single_atom),
            np.eye(dim**(n - site - 1), dtype=np.float64),
        )
        for site in range(n)
    ]

def enumerator_avg(basis, n=1, enumeration=None):
    """
    Generate an "average enumeration" operator: the sum of the per-atom
    enumeration operators returned by `enumerators`, divided by `n`. It gives
    the average number assigned to the atoms in a many-body state.

    Parameters
    ----------
    basis : Basis
        Single-atom basis.
    n : int (optional)
        Number of atoms. Defaults to 1.
    enumeration : list[int] (optional)
        Customize the enumeration of the single-atom states. If left
        unspecified, use [0, ..., B - 1], where B is the length of `basis`.

    Returns
    -------
    E_n : numpy.ndarray[dim=2, dtype=numpy.float64]
        `n`-atom average enumeration operator.
    """
    # Forward `enumeration` so a custom numbering is honored.
    return sum(enumerators(basis, n, enumeration)) / n