Skip to content
Snippets Groups Projects
array_analysis.py 22.16 KiB
from __future__ import annotations
from pathlib import Path
from itertools import product
import os
import shutil
import sys
import numpy as np
import matplotlib.pyplot as pp
import lib.image as image
from lib.image import ROI, S
import lib.params as params
from lib.params import (ParamSpec as PS, load_params, DataSet)
import lib.plotting as plotting

def qq(X):
    """Echo `X` to stdout and hand it back unchanged (inline debugging aid)."""
    shown = X
    print(shown)
    return shown

datadir = Path(r"C:\Users\Covey Lab\Documents\Andor Solis\atomic_data")
date = "20230207"

# raw image stacks to process; each lives in <datadir>/<date>/
infiles = [
    datadir / date / name
    for name in (
        "zero-rabi-hist_015.fits",
    )
]

exptdir = Path(r"\\ANACONDA\Users\EW\Documents\Data\tweezer atoms")
# `exptdir` is where the driving script's output lives: process_file reads
# <exptdir>/<date>/<image file stem>/params.toml (and comments.txt)
#
# select variables to track from the driving script
# all variables are converted to lower case
# the order of this list must strictly follow the scan order
# each PS == ParamSpec follows the format
#   PS(name: str, unit: str, mapping_func: callable)
# with the last two optional
# (mapping_func converts the stored value into `unit`, e.g. s -> ms)
paramslist = [
    ### const
    # PS("dispenser", "A"),

    ### init
    # PS("t0", "ms", lambda t: 1000.0 * t),
    # PS("u_nominal", "uK"),

    ### cooling - serial
    PS("p_cool", "Isat"),
    PS("det_cool", "MHz"),

    ### probing - serial
    # PS("p_probe", "dBm"),
    PS("det_probe", "MHz"),

    ### cmot
    # PS("shims_cmot_fb", "V"),
    # PS("shims_cmot_lr", "V"),
    # PS("shims_cmot_ud", "V"),

    ### load
    # PS("shims_smear_fb", "V"),
    # PS("shims_smear_lr", "V"),
    # PS("shims_smear_ud", "V"),
    # PS("tau_load", "ms", lambda t: 1000.0 * t),

    ### cool
    # PS("shims_cool_fb", "G"),
    # PS("shims_cool_lr", "G"),
    # PS("shims_cool_ud", "G"),
    PS("tau_cool", "ms", lambda t: 1000.0 * t),

    ### pump
    # PS("u_pump", "uK"),
    # PS("shims_pump_fb", "G"),
    # PS("shims_pump_lr", "G"),
    # PS("shims_pump_ud", "G"),
    # PS("hholtz_pump", "G"),
    # PS("tau_pump", "ms", lambda t: 1000.0 * t),
    # PS("p_pump", "Isat"),
    # PS("det_pump", "MHz"),

    ### mag
    # PS("shims_mag_fb", "G"),
    # PS("shims_mag_lr", "G"),
    PS("shims_mag_ud", "G"),
    # PS("shims_mag_amp_fb", "G"),
    PS("shims_mag_amp_lr", "G"),
    PS("tau_mag", "ms", lambda t: 1000.0 * t),
    PS("f_mag", "kHz", lambda f: 1000.0 * f), # NOTE(review): x1000 yields kHz only if the stored value is in MHz -- confirm

    ### kill
    # PS("shims_kill_fb", "G"),
    # PS("shims_kill_lr", "G"),
    # PS("shims_kill_ud", "G"),
    # PS("hholtz_kill", "G"),
    # PS("tau_kill", "ms", lambda t: 1000.0 * t),
    # PS("p_kill", "Isat"),
    # PS("det_kill", "MHz"),

    ### test
    # PS("u_test", "uK"),
    # PS("shims_test_fb", "G"),
    # PS("shims_test_lr", "G"),
    # PS("shims_test_ud", "G"),
    # PS("hholtz_test", "G"),
    # PS("tau_test", "ms", lambda t: 1000.0 * t),
    # PS("p_test_cmot_beg", "dBm"),
    # PS("p_test_cmot_end", "dBm"),
    # PS("det_test_cmot_beg", "MHz"),
    # PS("det_test_cmot_end", "MHz"),
    # PS("p_test_probe_am", "Isat"),

    ### lifetime
    # PS("u_lifetime", "uK"),
    # PS("shims_lifetime_fb", "G"),
    # PS("shims_lifetime_lr", "G"),
    # PS("shims_lifetime_ud", "G"),
    # PS("hholtz_lifetime", "G"),
    # PS("tau_lifetime", "ms", lambda t: 1000.0 * t),
    # PS("p_lifetime_cmot_beg", "dBm"),
    # PS("p_lifetime_cmot_end", "dBm"),
    # PS("det_lifetime_cmot_beg", "MHz"),
    # PS("det_lifetime_cmot_end", "MHz"),
    # PS("p_lifetime_probe_am", "Isat"),

    ### rampdown
    PS("u_rampdown", "uK"),
    # PS("tau_rampdown", "ms", lambda t: 1000.0 * t),

    ### release-recapture
    # PS("tau_tof", "us", lambda t: 1e6 * t),

    ### rampdown 2
    # PS("u_rampdown_2", "uK"),

    ### image
    # PS("shims_probe_fb", "G"),
    # PS("shims_probe_lr", "G"),
    # PS("shims_probe_ud", "G"),
    PS("hholtz_probe", "G"),
    PS("tau_probe", "ms", lambda t: 1000.0 * t),
    PS("p_probe_am", "Isat"),
]

# extra variables not listed above
# items follow the format
# parameter_name: str | (str, str) -> value: float | numpy.ndarray[ndim=1]
# (handed to load_params together with `paramslist` in process_file)
extra_params = {
}

### camera settings
# forwarded positionally to image.to_photons(...) in process_file to convert
# raw camera counts into photons
readout_rate: float = 1.0 # MHz
preamp: int = 2
em_gain: int = 100
count_bias: float = 500.0
QE: float = 0.8

### data selection options
roi_dim: list[int] = [3, 3] # [ w, h ]
roi_locs: list[list[int]] = [ # [ x (j), y (i) ]
    [11, 3],
    # [16, 4],
    # [42, 57],
]
num_images: int = 3 # shots per param config (survival output requires > 1)
optim_pad: int = 0 # additional padding area for ROI optimization
hist_bin_size: int = 3 # histogram bin width, in photons
threshold: float = 7.0 # photon threshold (DataSet.from_raw and histogram plots)

### processing options
# take only this number of shots or fewer (None or any value <= 0 keeps the
# whole stack)
take: int = -1
# delete pre-existing sub-directories and files generated by this script
renew_files: bool = True
# umbrella switch for all plotting
do_plotting: bool = True
# switch for only plotting related to individual parameter configs
do_indiv_plotting: bool = True
render_whole_image: bool = True
render_subimages: bool = False
plot_hist_totals: bool = True
draw_roi_guides: bool = True
# sort parameter axes according to parameter array values
sort_axes: bool = True
# plot against two variables using slices instead of a color plot
force_lines: bool = True
# plot MPC/FAT/survival as slices versus these parameters (must be 1 or 2)
plot_versus: list[str] = ["tau_mag"]
# for multiline plots, plot versus this `plot_versus` axis (must be 0 or 1)
multiline_plot_versus: int = 0
# plot info on this shot (used as an index along the shot axis)
shot_num: int = 1
# add violin plots to MPC lineplots
violins_mpc: bool = False
# normalize MPC line plots to the max along the plotting axis
normalize_mpc: bool = False
# plot histograms/MPC/FAT for each ROI independently
# (requires len(plot_versus) == 1)
indep_rois: bool = False
plot_cmap = "jet" # colormap for 2-D color plots
image_cmap = "gray" # colormap for rendered camera images

### derived quantities -- do not edit directly
rois: list[ROI] = [ROI(loc, roi_dim, optim_pad) for loc in roi_locs]
# positions of the `plot_versus` parameters within `paramslist`
plot_versus_axes = [paramslist.index(p) for p in plot_versus]
if len(plot_versus_axes) not in {1, 2}:
    # a string passed to sys.exit is written to stderr and the process exits
    # with status 1 (same status as before, but the error no longer goes to
    # stdout)
    sys.exit("can only plot against 1 or 2 axes")

################################################################################

# Outputs written by `process_file` directly inside its output directory;
# removed before a rerun when `renew_files` is set.
_GENERATED_FILES = frozenset({
    "roi_totals.npz",
    "fat_data.npz",
    "mpc_data.npz",
    "survival_data.npz",
})
# Output sub-directories created by `process_file` (including "subimg", which
# was previously left behind with stale renders).
_GENERATED_DIRS = frozenset({
    "fat",
    "hist",
    "img",
    "mpc",
    "sequencing",
    "subimg",
    "survival",
})


def _ensure_dir(path: Path) -> None:
    """Create `path` (with parents) unless it is already a directory."""
    if not path.is_dir():
        print(f":: mkdir -p {path}")
        path.mkdir(parents=True)


def _remove_generated(outdir: Path) -> None:
    """Delete entries of `outdir` that a previous run of this script produced.

    Only names listed in `_GENERATED_FILES` / `_GENERATED_DIRS` are touched;
    anything else in `outdir` is left alone.
    """
    for name in os.listdir(outdir):
        target = outdir.joinpath(name)
        if name in _GENERATED_FILES:
            print(f":: rm {target}")
            os.remove(str(target))
        elif name in _GENERATED_DIRS:
            print(f":: rm -rf {target}")
            shutil.rmtree(str(target), ignore_errors=True)


def _save_quantity(path: Path, dataset, name: str, quantity) -> None:
    """Write `quantity` (value at [0], error at [1]) to an .npz at `path`.

    The archive holds every parameter array of `dataset.params` keyed by the
    parameter name, plus `name` and `name + "_err"`.
    """
    np.savez(
        str(path),
        **{p[0]: v for p, v in dataset.params.items()},
        **{name: quantity[0], f"{name}_err": quantity[1]},
    )


def _png_name(fmt: str, *V) -> str:
    """Build a flat .png filename from the plot title `fmt.format(*V)`."""
    return (
        fmt.format(*V)
            .replace("\n", "")
            .replace("|", "")
            .replace("_", "-")
            .replace("  ", "_")
            .replace(" ", "")
        + ".png"
    )


def _axis_label(label) -> str:
    """Format a (name, unit) parameter label as ``"name [unit]"``."""
    return f"{label[0]} [{label[1]}]"


def _save_close(fig, path: Path) -> None:
    """Save `fig` to `path`, then close it so pyplot does not accumulate figures."""
    fig.savefig(path)
    pp.close(fig)


def _render_aggregated(
    dataset,
    outdir_mpc: Path,
    outdir_fat: Path,
    outdir_survival: Path,
    rspace: int,
    printflag: bool,
) -> None:
    """Plot MPC, FAT and (when present) survival versus the `plot_versus` axes.

    One figure per quantity is written for every combination of the remaining
    parameters: a line plot for one plotting axis; for two, a multi-line plot
    (`force_lines`) or a color plot plus a separate ``*_err.png`` error map.
    Only MPC honors `violins_mpc` and `normalize_mpc`.
    """
    if printflag:
        print("  aggregated quantities:")
    newline = "\n"
    # title format over the non-plotted parameters, wrapped every 4 entries
    fmt = " | ".join(
        f"{p[0]}={{:+.3f}}{p[1]}{newline if k % 4 == 3 else ''}"
        for k, p in enumerate(
            p_ for k_, p_ in dataset.params.keys_indexed()
            if k_ not in plot_versus_axes
        )
    )
    plotvals, _iterlabels, iterator = dataset.iter_agg(
        plot_versus_axes,
        shot_num,
    )
    pltlabels, pltvalues = zip(*plotvals)
    for Q, V, roi_totals, mpc, fat, survival in iterator:
        if printflag:
            print("\r  ", *Q, rspace * " ", end="", flush=True)
        title = fmt.format(*V)
        pngname = _png_name(fmt, *V)
        panels = [
            # (data, axis label, output dir, is MPC)
            (mpc, "Mean photon count", outdir_mpc, True),
            (fat, "Fraction above threshold", outdir_fat, False),
            (survival, "Survival fraction", outdir_survival, False),
        ]
        for data, label, target_dir, is_mpc in panels:
            if data is None:
                # survival may be unavailable (also checked when saving .npz)
                continue
            normalize = normalize_mpc if is_mpc else False
            with_violins = is_mpc and violins_mpc

            if len(pltlabels) == 1:
                fig, _ = plotting.lineplot(
                    x=pltvalues[0],
                    y=data[0],
                    err=data[1],
                    pop=roi_totals[0] if with_violins else None,
                    xlabel=_axis_label(pltlabels[0]),
                    ylabel=label,
                    title=title,
                    normalize=normalize,
                    indep_rois=indep_rois,
                )
                _save_close(fig, target_dir.joinpath(pngname))
                continue

            # two plotting axes: average the value over axis 0 of data[0]
            # (presumably the ROI axis -- confirm) and combine the errors in
            # quadrature, divided by that axis' length
            Z = data[0].mean(axis=0)
            ERR = np.sqrt((data[1]**2).sum(axis=0)) / data.shape[1]
            if force_lines:
                fig, _ = plotting.multilineplot(
                    x=pltvalues[1],
                    y=pltvalues[0],
                    Z=Z,
                    ERR=ERR,
                    POP=roi_totals[0].mean(axis=0) if with_violins else None,
                    xlabel=_axis_label(pltlabels[1]),
                    ylabel=_axis_label(pltlabels[0]),
                    zlabel=label,
                    title=title,
                    versus_axis=multiline_plot_versus,
                    normalize=normalize,
                )
                _save_close(fig, target_dir.joinpath(pngname))
            else:
                (fig, _), (fig_err, _) = plotting.colorplot(
                    x=pltvalues[1],
                    y=pltvalues[0],
                    Z=Z,
                    ERR=ERR,
                    cmap=plot_cmap,
                    xlabel=_axis_label(pltlabels[1]),
                    ylabel=_axis_label(pltlabels[0]),
                    clabel=label,
                    title=title,
                )
                _save_close(fig, target_dir.joinpath(pngname))
                _save_close(
                    fig_err,
                    target_dir.joinpath(Path(pngname).stem + "_err.png"),
                )
    if printflag:
        print("")


def _render_individual(
    dataset,
    outdir_img: Path,
    outdir_subimg: Path,
    outdir_hist: Path,
    rspace: int,
    printflag: bool,
) -> None:
    """Render per-parameter-configuration images and histograms for `shot_num`.

    Color/bin limits are shared across all configurations so the figures are
    directly comparable.
    """
    if printflag:
        print("  populational quantities:")
    newline = "\n"
    fmt = " | ".join(
        f"{p[0]}={{:.3f}}{p[1]}{newline if k % 4 == 3 else ''}"
        for k, p in dataset.params.keys_indexed()
    )
    shot = int(shot_num)
    # one full slice per parameter axis listed in `paramslist`
    param_slices = len(paramslist) * [S[:]]
    image_vmax = dataset.img_avg[(
        *param_slices, shot, S[:], S[:]
    )].max()
    subimage_vmax = dataset.roi_img_avg[(
        S[:], *param_slices, shot, S[:], S[:]
    )].max()
    shot_totals = dataset.roi_totals[(
        0, S[:], S[:], *param_slices, shot
    )]
    hist_vmax = shot_totals.max()
    hist_vmin = shot_totals.min()
    # shared bins; always extend at least 3 bins below zero
    hist_bins = np.arange(
        min(hist_vmin, -3 * hist_bin_size),
        hist_vmax + hist_bin_size,
        hist_bin_size,
    )

    _iterlabels, iterator = dataset.iter_indiv(shot_num)
    for Q, V, img, locs, subimgs, roi_totals in iterator:
        if printflag:
            print("\r  ", *Q, rspace * " ", end="", flush=True)
        title = fmt.format(*V)
        pngname = _png_name(fmt, *V)

        # whole image
        if render_whole_image and do_indiv_plotting:
            fig, _ = plotting.render_image(
                img=img,
                cmap=image_cmap,
                vmin=0.0,
                vmax=image_vmax,
                clabel="Photons",
                title=title,
                draw_guides=draw_roi_guides,
                rois=rois,
                roi_optim_locs=locs,
            )
            _save_close(fig, outdir_img.joinpath(pngname))

        # render each sub-image
        if render_subimages and do_indiv_plotting:
            fig, _ = plotting.render_subimages(
                imgs=subimgs,
                cmap=image_cmap,
                vmin=0.0,
                vmax=subimage_vmax,
                clabel="Photons",
                title=title,
                draw_guides=draw_roi_guides,
            )
            _save_close(fig, outdir_subimg.joinpath(pngname))

        # histogram
        if plot_hist_totals and do_indiv_plotting:
            fig, _ = plotting.histogram(
                data=roi_totals[0],
                mean=roi_totals[0].mean(),
                mean_err=np.sqrt((roi_totals[1]**2).sum())
                    / roi_totals.shape[1] / roi_totals.shape[2],
                threshold=threshold,
                bins=hist_bins,
                density=True,
                xlabel="Photons",
                ylabel="Prob. density",
                title=title,
            )
            _save_close(fig, outdir_hist.joinpath(pngname))

            if indep_rois:
                fig, _ = plotting.subimages_histogram(
                    data=roi_totals[0],
                    means=roi_totals[0].mean(axis=1),
                    means_err=np.sqrt((roi_totals[1]**2).sum(axis=1))
                        / roi_totals.shape[2],
                    threshold=threshold,
                    bins=hist_bins,
                    density=True,
                    xlabel="Photons",
                    ylabel="Prob. density",
                    title=title,
                )
                _save_close(
                    fig,
                    outdir_hist.joinpath(
                        pngname.replace(".png", "_indep.png")
                    ),
                )
    if printflag:
        print("")


def process_file(filepath: Path, printflag: bool = True):
    """Analyze one image stack and write its numeric and graphical outputs.

    Reads the image stack at `filepath` and the scan parameters from
    ``<exptdir>/<date>/<filepath.stem>/params.toml``. It then converts counts
    to photons and builds a `DataSet`. Results go to
    ``<filepath.parent>/<filepath.stem>/``:

    - ``sequencing/``: copies of ``params.toml`` and ``comments.txt``
    - ``roi_totals.npz``, ``mpc_data.npz``, ``fat_data.npz`` and, when
      survival data exists, ``survival_data.npz``
    - if `do_plotting`: PNG figures under ``mpc/``, ``fat/``, ``survival/``,
      ``img/``, ``subimg/`` and ``hist/``, per the module plotting switches

    All other behavior is driven by the module-level configuration.

    Args:
        filepath: path to the raw image file.
        printflag: print progress messages when True.
    """
    if printflag:
        print(f"read image data from {filepath}")
    image_data = image.load_image(filepath.parent, filepath.name)
    if take is not None and take > 0:
        image_data = image_data[:take, :, :]
    if printflag:
        print(image_data.shape)

    paramsdir = exptdir.joinpath(date).joinpath(filepath.stem)
    if printflag:
        print(f"read params data from {paramsdir.joinpath('params.toml')}")
    # named so it does not shadow the module-level `params` import
    scan_params, reps = load_params(
        paramsdir,
        "params.toml",
        paramslist,
        extra_params,
    )

    outdir = filepath.parent.joinpath(filepath.stem)
    if renew_files and outdir.is_dir():
        _remove_generated(outdir)
    _ensure_dir(outdir)

    sequencingdir = outdir.joinpath("sequencing")
    _ensure_dir(sequencingdir)
    shutil.copy(paramsdir.joinpath("params.toml"), sequencingdir)
    shutil.copy(paramsdir.joinpath("comments.txt"), sequencingdir)
    # shutil.copy(paramsdir.joinpath("order.toml"), sequencingdir)

    if printflag:
        print("compute derived quantities")
    image_photons = image.to_photons(
        image_data,
        readout_rate,
        preamp,
        em_gain,
        count_bias,
        QE,
    )
    dataset = DataSet.from_raw(
        scan_params,
        reps,
        num_images,
        np.array(image_photons),
        rois,
        threshold,
        sort_axes,
    )

    if printflag:
        print("write data to files")
    for filename, name, quantity in [
        ("roi_totals.npz", "roi_totals", dataset.roi_totals),
        ("mpc_data.npz", "mpc", dataset.mpc),
        ("fat_data.npz", "fat", dataset.fat),
        ("survival_data.npz", "survival", dataset.survival),
    ]:
        # survival may be None; nothing is written for a missing quantity
        if quantity is not None:
            _save_quantity(outdir.joinpath(filename), dataset, name, quantity)

    if do_plotting:
        outdir_mpc = outdir.joinpath("mpc")
        outdir_fat = outdir.joinpath("fat")
        outdir_survival = outdir.joinpath("survival")
        outdir_img = outdir.joinpath("img")
        outdir_subimg = outdir.joinpath("subimg")
        outdir_hist = outdir.joinpath("hist")
        for sub_outdir, wanted in [
            (outdir_mpc, True),
            (outdir_fat, True),
            (outdir_survival, num_images > 1),
            (outdir_img, render_whole_image and do_indiv_plotting),
            (outdir_subimg, render_subimages and do_indiv_plotting),
            (outdir_hist, plot_hist_totals and do_indiv_plotting),
        ]:
            if wanted:
                _ensure_dir(sub_outdir)

        if printflag:
            print("render output visuals")
        # width used to blank out the "\r"-overwritten progress line
        rspace = sum(
            int(np.log10(len(X)) + 1) for X in dataset.params.values()
        )
        _render_aggregated(
            dataset,
            outdir_mpc,
            outdir_fat,
            outdir_survival,
            rspace,
            printflag,
        )
        _render_individual(
            dataset,
            outdir_img,
            outdir_subimg,
            outdir_hist,
            rspace,
            printflag,
        )

    if printflag:
        print("done.")

if __name__ == "__main__":
    # run the full analysis on each configured input file, in order
    for path in infiles:
        process_file(path)