from AWG import *
import time


class Waveform:
    """
    container for the tone parameters of a 1D tweezer array: 2n+1 equally spaced tones centered on cf.
    """

    def __init__(self, cf: int, df: int, n: int, sample_rate: int):
        """
        helper class to store basic waveform information.
        :param cf: center frequency tone of tweezer array.
        :param df: differential frequency between neighboring tweezers.
        :param n: number of tweezers to the left/right of center frequency tone, total number of tweezer is 2n+1.
        :param sample_rate: sampling rate of the AWG to generate correct number of samples.

        Attributes set:
        amplitude -- per-tone amplitude, currently ONE scalar shared by every tone: 2**11 / sqrt(2n+1).
        omega -- angular frequencies 2*pi*f of the 2n+1 tones, ascending from cf - n*df to cf + n*df (spacing 2*pi*df).
        phi -- initial phases drawn uniformly from [0, 2*pi) with numpy's global RNG (seed it for reproducibility).
        sample_rate -- as passed in.
        """
        # presumably 2**11 maps the unit-amplitude sine onto the AWG's integer DAC range
        # (other groups use 2**15) -- TODO confirm against the AWG's resolution
        scale = 2**11
        num_tz = 2*n + 1  # total number of tweezers to be generated
        # dividing by sqrt(N) presumably keeps the RMS of the summed signal roughly independent of the
        # tweezer count, since tones with random phases add in quadrature -- TODO confirm intended headroom
        max_amp = scale / np.sqrt(num_tz)

        # uniform amplitude for now: a scalar, not an array. It will eventually be tuned per tweezer with experiments.
        self.amplitude: float = max_amp
        self.omega: np.ndarray = 2*np.pi * np.linspace(cf - n*df, cf + n*df, num_tz)  # frequency tones (rad/s)
        self.phi: np.ndarray = 2*np.pi * np.random.rand(num_tz)  # random initial phases in [0, 2pi), to be tuned experimentally
        self.sample_rate: int = sample_rate


def create_static_array(wfm: "Waveform", sample_len: int) -> np.ndarray:
    """
    create a static-array-generating waveform with user set number of samples
    :param wfm: waveform object already initialized with basic parameters; only its amplitude, omega, phi
        and sample_rate attributes are read. amplitude may be a scalar (one value for every tone) or a
        1D array with one entry per tone.
    :param sample_len: total number of samples to generate, must be multiples of 512 (not checked here).
        Note that more sample != higher resolution.
    :return: 1D float array of length sample_len holding
        sum_k amplitude[k] * sin(omega[k] * t + phi[k]) with t = arange(sample_len) / sample_rate.
    """
    # time axis in seconds: sample m sits at m / sample_rate
    t = np.arange(sample_len) / wfm.sample_rate

    # amplitude as a column: a scalar becomes (1, 1) and broadcasts to every tone, a per-tweezer array
    # becomes (n_tones, 1) and scales its own row. A bare 1D array would instead broadcast against the
    # time axis (the last axis) and fail or mix up tones and samples.
    amp = np.reshape(wfm.amplitude, (-1, 1))
    phase0 = np.reshape(wfm.phi, (-1, 1))

    # sin_mat[k] holds the samples of the kth tweezer, shape = (number of tweezers, sample_len)
    sin_mat = amp * np.sin(np.outer(wfm.omega, t) + phase0)

    # superpose all tones into the final signal
    return sin_mat.sum(axis=0)


def create_path_table(wfm: Waveform) -> "tuple[np.ndarray, np.ndarray]":
    """
    create a dim-3 look up table where table[i,j] contains the sine-wave samples that move tweezer i to tweezer j.

    Every move i -> j sweeps the tone from omega_i to omega_j. On each half of the move the instantaneous
    frequency is quadratic in time, so the phase is cubic. The move time is stretched to an integer multiple
    of 12*pi/|dw|, and the tone then continues at omega_j for the rest of the table.
    See Group Notes/Projects/Rearrangement for the derivation.

    Memory: the table is built from float64 arrays of n*n*sample_len_max samples (phases, then sine values),
    plus an int copy made by astype(int), so peak memory is a few times n*n*sample_len_max*8 bytes and
    grows quadratically with the number of tweezers.

    :param wfm: waveform object already initialized with basic parameters (reads omega, phi, amplitude and
        sample_rate).
    :return: tuple (table, static_sig):
        table -- int ndarray of shape (n, n, sample_len_max) with n = len(wfm.omega). For i != j,
            table[i, j] holds the move i -> j. table[i, i] holds the static tone
            amplitude*sin(omega_i*t + phi_i), truncated to int.
        static_sig -- int ndarray of shape (sample_len_max,): the diagonal entries summed over tweezers,
            i.e. every tweezer held static for the length of the table.
    """
    # setup basic variables
    twopi = 2*np.pi
    # NOTE(review): vmax is in Hz/s but wfm.omega (and so dw_max) is in rad/s (Waveform stores 2*pi*f),
    # so t_max below equals 2*pi * 2*(f_max - f_min)/vmax -- confirm which units were intended.
    vmax = KILO(20) * MEGA(1)  # convert units, 20 kHz/us -> 20e3 * 1e6 Hz/s
    dw_max = wfm.omega[-1] - wfm.omega[0]  # longest move in (angular) frequency; assumes omega is ascending, as Waveform builds it
    t_max = 2 * dw_max / vmax  # Longest move sets the maximum moving time
    # the sign of a_max does not matter in this function: it is only used inside abs() below
    a_max = -vmax * 2 / t_max  # maximum acceleration, negative sign because of magic
    # NOTE(review): with a_max as defined, the minimum time of the longest move is
    # sqrt(4*dw_max/|a_max|) = t_max, but the table only holds 4/5 of t_max. Such moves, and any move whose
    # phase-matched t_tot exceeds the table, are silently cut off by the slices below and never reach
    # omega_j. If even half a move does not fit, phi_paths[i,j,half-1] raises IndexError.
    # Confirm the 4/5 factor is intended.
    sample_len_max = int(np.ceil(t_max * 4/5 * wfm.sample_rate))  # number of samples that sets the size of the lookup table
    # NOTE(review): when sample_len_max is already a multiple of 512 this still adds a full extra 512 samples.
    sample_len_max += (512 - sample_len_max % 512)  # make overall length a multiple of 512 so AWG doesn't freak out

    # now we calculate the phase of every possible trajectory, go to Group Notes/Projects/Rearrangement for detail
    n = len(wfm.omega)  # total number of tweezers
    phi_paths = np.zeros((n, n, sample_len_max))  # phase lookup table (float64), turned into samples at the end
    t = np.arange(sample_len_max) / wfm.sample_rate  # time axis in seconds
    # iterate! I think this part can be vectorized as well... but unnecessary.
    for i, omega_i in enumerate(wfm.omega):
        for j, omega_j in enumerate(wfm.omega):  # j is the target position, i is starting position
            if i == j:
                phi_paths[i,i] = omega_i*t + wfm.phi[i]
                continue  # diagonal = static tone of tweezer i, no move to compute
            dw = omega_j - omega_i  # delta omega in the equation (signed: negative for moves to lower frequency)
            adw = abs(dw)

            # I advise reading through the notes page first before going further
            phi_j = wfm.phi[j] % twopi  # wrap into [0, 2pi)
            phi_i = wfm.phi[i] % twopi
            dphi = phi_j - phi_i  # delta phi in the equation
            # NOTE(review): for dphi < 0 the line below evaluates to 2*pi - phi_j, not dphi + 2*pi (the usual
            # wrap into [0, 2pi)). Check against the notes before relying on the phase matching.
            if dphi < 0: dphi = abs(dphi) + twopi - phi_i  # wrap around for negative phase shift
            t_tot = np.sqrt(abs(4 * dw / a_max))  # minimum time to complete the move at |a_max|
            # The result is always an integer multiple of 12*pi/adw; the dphi term only shifts which multiple
            # is picked. NOTE(review): this rounds down to 0 when the minimum time is below 6*dphi/adw, which
            # makes `a` infinite (numpy divide warning) and the move zero-length.
            t_tot = ((t_tot - 6*dphi/adw) // (12*np.pi/adw) + 1) * 12*np.pi/adw  # extend move time to arrive at the correct phase
            a = 4*dw/(t_tot**2)  # adjust acceleration so the sweep still ends exactly at omega_j
            end = int(np.ceil(t_tot * wfm.sample_rate))  # number of samples the move takes
            half = int(end / 2)  # index of sample half-way through the move where equation changes
            # both slices below (and the matching phi_paths slices) are clipped to the table length if end > sample_len_max
            t1 = t[:half]  # first half of the move, slicing to make life easier
            t2 = t[half:end] - t_tot/2  # time series for second half of the move, measured from T/2

            # do calculation. Instantaneous frequency is omega_i + a/2*t^2 on the first half and
            # omega_i + a*T^2/8 + a*(T/2)*t2 - a/2*t2^2 on the second half, which reaches
            # omega_i + a*T^2/4 = omega_j at t = T.
            phi_paths[i,j, :half] = wfm.phi[i] + omega_i*t1 + a/6*t1**3  # t<=T/2
            # the second half continues from the last first-half sample, phi_paths[i,j,half-1]
            phi_paths[i,j, half:end] = phi_paths[i,j,half-1] + (omega_i+a/2*(t_tot/2)**2)*t2 + a/2*t_tot/2*t2**2 - a/6*t2**3  # t>=T/2
            # The tail is a tone at omega_j whose phase is congruent to phi_j (mod 2pi) at t = t_tot.
            # NOTE(review): the static diagonal entry j (omega_j*t + phi_j) has phase phi_j + omega_j*t_tot
            # at that time. Integrating the two halves gives roughly phi_i + (omega_i + omega_j)*t_tot/2 at the
            # end of the move. Confirm against the notes that these are meant to line up.
            phi_paths[i,j, end:] = omega_j*t[end:] + (phi_j - omega_j*t_tot) % twopi  # fill the rest with parameters of target wave

    # now compile everything into sine wave
    # NOTE(review): this assumes wfm.amplitude is a scalar (as Waveform sets it); a per-tweezer 1D array would
    # broadcast against the last (time) axis instead of the tweezer axes.
    phi_paths = wfm.amplitude * np.sin(phi_paths)
    # astype(int) truncates toward zero. diagonal() has shape (sample_len_max, n); after .T, summing axis 0
    # adds the static tones of all tweezers.
    return phi_paths.astype(int), np.sum(phi_paths.diagonal().T, axis=0, dtype=int)


def create_moving_array(sig: np.ndarray, path_table: np.ndarray, paths: np.ndarray) -> np.ndarray:
    """
    create a rearranging signal that moves tweezers as specified by paths.

    For every (i, j) pair, the static tone of tweezer i (path_table[i, i]) is removed from sig and the
    trajectory i -> j (path_table[i, j]) is added. sig is modified IN PLACE and also returned.

    :param sig: 1D signal of length path_table.shape[2] to update, e.g. the summed static signal returned as
        the second value of create_path_table(). Its dtype must accept in-place addition of path_table's values.
    :param path_table: lookup table returned (first value) from create_path_table(). path_table[i, j] holds the
        move i -> j and path_table[i, i] the static tone of tweezer i.
    :param paths: (k, 2) array-like of (start, target) index pairs. Example: np.array([(1,0),(2,1)]) moves
        tweezer1->0, tweezer2->1. A single pair, a list of tuples, or an empty sequence is also accepted.
    :return: sig, after the in-place update.
    """
    paths = np.asarray(paths).reshape(-1, 2)
    if paths.size == 0:
        return sig  # nothing to move

    src, dst = paths[:, 0], paths[:, 1]
    # swap each start tweezer's static tone for its moving trajectory
    sig += np.sum(path_table[src, dst], axis=0) - np.sum(path_table[src, src], axis=0)

    # NOTE(review): tweezers that are targets but not sources (np.setdiff1d(dst, src)) keep their static tone.
    # So moving 3->2, 2->1, 1->0 leaves tweezer 0's tone on, next to the tone arriving from 1, once the move
    # completes. An earlier sin_mat-based version zeroed those rows. To reproduce that, also subtract
    # path_table[k, k] for k in np.setdiff1d(dst, src). Confirm which behavior is intended.
    return sig


def old_code():
    """
    Archive of retired snippets kept for reference. Calling it has no effect and returns None.

    Nothing below runs. The snippets refer to names (wfm, t, n, phi_paths, ...) that do not exist
    in this scope, so they cannot simply be uncommented here.
    """
    # --- earlier variant of create_path_table: no phase matching at the end of a move ---
    # def create_phase_paths_old(wfm):
    #     # setup basic variables
    #     vmax = KILO(20) * MEGA(1)  # 20 kHz/us -> 20 kHz * 1e6 / s
    #     dw_max = wfm.omega[-1] - wfm.omega[0]  # Longest move in frequency
    #     t_max = 2 * dw_max / vmax  # Longest move sets the maximum moving time
    #     a = -vmax * 2 / t_max  # constant acceleration, negative sign because of magic
    #     sample_len = int(np.ceil(t_max * wfm.sample_rate))  # get number of samples required for longest move
    #     sample_len += (512 - sample_len % 512)  # make overall length a multiple of 512
    #     # t = np.zeros(padding+sample_len)
    #     t = np.arange(sample_len) / wfm.sample_rate  # get time series
    #
    #     # generate phi for every possible move
    #     n = len(wfm.omega)  # total number of tweezers
    #     phi_paths = np.zeros((n, n, sample_len))  # map to store all moves
    #     for i, omega_i in enumerate(wfm.omega):
    #         for j, omega_j in enumerate(wfm.omega):  # I set j to be the target position, i to be starting position
    #             if i == j: continue  # skip diagonal entries
    #             dw = omega_j - omega_i  # delta omega
    #             t_tot = np.sqrt(abs(4 * dw / a))  # total time to travel dw, t_tot <= t_max
    #             end = int(np.ceil(t_tot * wfm.sample_rate))  # total number of samples to move dw
    #             half = int(end / 2)  # index of sample half-way through the move
    #             t1 = t[:half]  # time series for first half of the move
    #             t2 = t[half:end] - t_tot/2  # time series for second half of the move
    #             # do calculation
    #             phi_paths[i,j, :half] = wfm.phi[i] + omega_i*t1 + a/6*t1**3  # t<=T/2 note we are changing the ith tweezer
    #             phi_paths[i,j, half:end] = phi_paths[i,j,half-1] + (omega_i+a/2*(t_tot/2)**2)*t2 + a/2*t_tot/2*t2**2 - a/6*t2**3  # t>=T/2
    #             # phi_paths[i,j, half:end] = phi_paths[i,j,half-1] + (omega_i+a*(t_tot/2)**2)*t2 + a/2*t_tot/2*t2**2 - a/6*t2**3  # t>=T/2
    #             phi_paths[i,j, end:] = omega_j*t[end:]  # fill the rest of the array with target frequency wave
    #             # phi_paths[i,j, end:] = phi_paths[i,j, end-1]*t[end:]  # fill the rest of the array with same value
    #     return phi_paths

    # --- static signal matrix written as a single expression ---
    # sig_mat = wfm.amplitude * np.sin(np.outer(wfm.omega,t) + wfm.phi[:, np.newaxis])  # shape=(number of tweezers x sample_len)

    # --- phase-debugging leftovers ---
    # self.debug["mat1"] = np.zeros((n, n, 2))  # for phase debugging
    # for diagnostic purposes
    # path_idx[i, j] = (t_tot, end)
    # phi_const[i,j, :half] = wfm.phi[i]
    # phi_const[i,j, half:end] = phi_const[i,j,half-1] +
    # print(a, t_tot)
    # return phi_paths, wfm.amplitude * np.sin(np.outer(wfm.omega,t) + np.expand_dims(wfm.phi, axis=1)), path_idx

    return None


def main():
    """
    sample usage: build the static signal and one rearrangement signal for 2n+1 = 7 tweezers,
    then save both as comma-delimited text files under data/ (created if missing).
    """
    import os  # only needed to make sure the output directory exists

    np.random.seed(0)  # fixed seed so the random tweezer phases are reproducible
    cf = MEGA(80)
    df = MEGA(1)
    n = 3
    sample = 512*1000  # static signal length, a multiple of 512 as create_static_array expects
    rate = MEGA(625)
    print(f"Sampling rate: {rate/1e6}MHz")
    print(f"center f: {cf/1e6}MHz\n"
          f"df: {df/1e6}MHz\n"
          f"total n: {2*n+1}")

    wave = Waveform(cf, df, n, rate)
    static_sig = create_static_array(wave, sample)
    static_sig = static_sig.astype(int)

    # move_sig starts as the summed static signal over the table length;
    # create_moving_array updates it in place
    table, move_sig = create_path_table(wave)
    repath = np.array([(1,0), (2,1), (3,2)])  # shift tweezers 1, 2, 3 down by one site
    create_moving_array(move_sig, table, repath)

    os.makedirs("data", exist_ok=True)  # np.savetxt does not create missing directories
    np.savetxt("data/static_signal.txt", static_sig, delimiter=',')
    np.savetxt("data/move_signal.txt", move_sig, delimiter=',')
    # np.savetxt("initial_phi.txt", wave.phi, delimiter=',')
    # np.savetxt("sig_mat.txt", sig_mat, delimiter=',')


def size_test(n):
    """
    report the memory footprint of the rearrangement lookup table for 2n+1 tweezers.

    The table is actually built, so this is slow and memory-hungry for large n.
    :param n: number of tweezers on each side of the center tone (total 2n+1).
    :return: size in bytes of the path table (first value returned by create_path_table()).
    """
    np.random.seed(0)
    cf = MEGA(80)
    df = MEGA(1)
    rate = MEGA(625)
    wave = Waveform(cf, df, n, rate)
    # create_path_table returns (table, static_signal); calling .nbytes on the tuple would raise AttributeError
    table, _ = create_path_table(wave)
    return table.nbytes


# main()