Skip to content
Snippets Groups Projects
Commit 496431a5 authored by camera computer's avatar camera computer
Browse files

small indexing bugfixes, tweak plotting

parent 2918b32e
No related branches found
No related tags found
No related merge requests found
**/__pycache__/* **/__pycache__/*
**.png **.png
**.svg **.svg
output/*
cmot-freq-scan.py
...@@ -13,12 +13,12 @@ from lib.params import (ParamSpec as PS, load_params, DataSet) ...@@ -13,12 +13,12 @@ from lib.params import (ParamSpec as PS, load_params, DataSet)
import lib.plotting as plotting import lib.plotting as plotting
datadir = Path(r"C:\Users\Covey Lab\Documents\Andor Solis\atomic_data") datadir = Path(r"C:\Users\Covey Lab\Documents\Andor Solis\atomic_data")
date = "20221028" date = "20221110"
infiles = [ infiles = [
datadir.joinpath(date).joinpath(infile) datadir.joinpath(date).joinpath(infile)
for infile in [ for infile in [
"ipc_012.fits", "probe-power-scan_007.fits",
] ]
] ]
...@@ -111,7 +111,7 @@ paramslist = [ ...@@ -111,7 +111,7 @@ paramslist = [
# PS("p_lifetime_probe_am", "Isat"), # PS("p_lifetime_probe_am", "Isat"),
### rampdown ### rampdown
# PS("u_rampdown", "uK"), PS("u_rampdown", "uK"),
# PS("tau_rampdown", "ms", lambda t: 1000.0 * t) # PS("tau_rampdown", "ms", lambda t: 1000.0 * t)
### release-recapture ### release-recapture
...@@ -144,16 +144,13 @@ QE: float = 0.8 ...@@ -144,16 +144,13 @@ QE: float = 0.8
### data selection options ### data selection options
roi_dim: list[int, 2] = [3, 3] # [ w, h ] roi_dim: list[int, 2] = [3, 3] # [ w, h ]
roi_locs: list[list[int, 2]] = [ # [ x (j), y (i) ] roi_locs: list[list[int, 2]] = [ # [ x (j), y (i) ]
[50, 47], [26, 3],
[51, 55], # [49, 63],
[52, 64],
[53, 73],
[54, 81],
] ]
is_prepost: bool = False # expect two shots per param config is_prepost: bool = True # expect two shots per param config
optim_pad: int = 0 # additional padding area for ROI optimization optim_pad: int = 0 # additional padding area for ROI optimization
hist_bin_size: int = 5 hist_bin_size: int = 5
threshold: float = 30.0 threshold: float = 25
### processing options ### processing options
# delete pre-existing sub-directories and files generated by this script # delete pre-existing sub-directories and files generated by this script
...@@ -171,7 +168,7 @@ sort_axes: bool = True ...@@ -171,7 +168,7 @@ sort_axes: bool = True
# plot against two variables using slices instead of a color plot # plot against two variables using slices instead of a color plot
force_lines: bool = True force_lines: bool = True
# plot MPC/FAT/survival as slices versus these parameters (must be 1 or 2) # plot MPC/FAT/survival as slices versus these parameters (must be 1 or 2)
plot_versus: list[str] = ["tau_probe",] plot_versus: list[str] = ["p_probe_am"]
# for multiline plots, plot versus this `plot_versus` axis (must be 0 or 1) # for multiline plots, plot versus this `plot_versus` axis (must be 0 or 1)
multiline_plot_versus: int = 0 multiline_plot_versus: int = 0
# plot pre/post shots: False => pre; True => post # plot pre/post shots: False => pre; True => post
...@@ -182,7 +179,7 @@ violins_mpc: bool = False ...@@ -182,7 +179,7 @@ violins_mpc: bool = False
normalize_mpc: bool = False normalize_mpc: bool = False
# plot histograms/MPC/FAT for each ROI independently # plot histograms/MPC/FAT for each ROI independently
# (requires len(plot_versus) == 1) # (requires len(plot_versus) == 1)
indep_rois: bool = True indep_rois: bool = False
plot_cmap = "jet" plot_cmap = "jet"
image_cmap = "gray" image_cmap = "gray"
...@@ -358,7 +355,7 @@ def process_file(filepath: Path, printflag: bool=True): ...@@ -358,7 +355,7 @@ def process_file(filepath: Path, printflag: bool=True):
y=mpc[0], y=mpc[0],
err=mpc[1], err=mpc[1],
pop=roi_totals[0] if violins_mpc else None, pop=roi_totals[0] if violins_mpc else None,
xlabel=pltlabels[0], xlabel=f"{pltlabels[0][0]} [{pltlabels[0][1]}]",
ylabel="Mean photon count", ylabel="Mean photon count",
title=fmt.format(*V), title=fmt.format(*V),
normalize=normalize_mpc, normalize=normalize_mpc,
...@@ -372,7 +369,7 @@ def process_file(filepath: Path, printflag: bool=True): ...@@ -372,7 +369,7 @@ def process_file(filepath: Path, printflag: bool=True):
y=fat[0], y=fat[0],
err=fat[1], err=fat[1],
pop=None, pop=None,
xlabel=pltlabels[0], xlabel=f"{pltlabels[0][0]} [{pltlabels[0][1]}]",
ylabel="Fraction above threshold", ylabel="Fraction above threshold",
title=fmt.format(*V), title=fmt.format(*V),
normalize=False, normalize=False,
...@@ -387,7 +384,7 @@ def process_file(filepath: Path, printflag: bool=True): ...@@ -387,7 +384,7 @@ def process_file(filepath: Path, printflag: bool=True):
y=survival[0], y=survival[0],
err=survival[1], err=survival[1],
pop=None, pop=None,
xlabel=pltlabels[0], xlabel=f"{pltlabels[0][0]} [{pltlabels[0][1]}]",
ylabel="Survival fraction", ylabel="Survival fraction",
title=fmt.format(*V), title=fmt.format(*V),
normalize=False, normalize=False,
...@@ -403,8 +400,8 @@ def process_file(filepath: Path, printflag: bool=True): ...@@ -403,8 +400,8 @@ def process_file(filepath: Path, printflag: bool=True):
Z=mpc[0].mean(axis=0), Z=mpc[0].mean(axis=0),
ERR=np.sqrt((mpc[1]**2).sum(axis=0)) / mpc.shape[1], ERR=np.sqrt((mpc[1]**2).sum(axis=0)) / mpc.shape[1],
POP=roi_totals[0].mean(axis=0) if violins_mpc else None, POP=roi_totals[0].mean(axis=0) if violins_mpc else None,
xlabel=pltlabels[1], xlabel=f"{pltlabels[1][0]} [{pltlabels[1][1]}]",
ylabel=pltlabels[0], ylabel=f"{pltlabels[0][0]} [{pltlabels[0][1]}]",
zlabel="Mean photon count", zlabel="Mean photon count",
title=fmt.format(*V), title=fmt.format(*V),
versus_axis=multiline_plot_versus, versus_axis=multiline_plot_versus,
...@@ -419,8 +416,8 @@ def process_file(filepath: Path, printflag: bool=True): ...@@ -419,8 +416,8 @@ def process_file(filepath: Path, printflag: bool=True):
Z=fat[0].mean(axis=0), Z=fat[0].mean(axis=0),
ERR=np.sqrt((fat[1]**2).sum(axis=0)) / fat.shape[1], ERR=np.sqrt((fat[1]**2).sum(axis=0)) / fat.shape[1],
POP=None, POP=None,
xlabel=pltlabels[1], xlabel=f"{pltlabels[1][0]} [{pltlabels[1][1]}]",
ylabel=pltlabels[0], ylabel=f"{pltlabels[0][0]} [{pltlabels[0][1]}]",
zlabel="Fraction above threshold", zlabel="Fraction above threshold",
title=fmt.format(*V), title=fmt.format(*V),
versus_axis=multiline_plot_versus, versus_axis=multiline_plot_versus,
...@@ -437,8 +434,8 @@ def process_file(filepath: Path, printflag: bool=True): ...@@ -437,8 +434,8 @@ def process_file(filepath: Path, printflag: bool=True):
ERR=np.sqrt((survival[1]**2).sum(axis=0)) ERR=np.sqrt((survival[1]**2).sum(axis=0))
/ survival.shape[1], / survival.shape[1],
POP=None, POP=None,
xlabel=pltlabels[1], xlabel=f"{pltlabels[1][0]} [{pltlabels[1][1]}]",
ylabel=pltlabels[0], ylabel=f"{pltlabels[0][0]} [{pltlabels[0][1]}]",
zlabel="Survival fraction", zlabel="Survival fraction",
title=fmt.format(*V), title=fmt.format(*V),
versus_axis=multiline_plot_versus, versus_axis=multiline_plot_versus,
...@@ -454,8 +451,8 @@ def process_file(filepath: Path, printflag: bool=True): ...@@ -454,8 +451,8 @@ def process_file(filepath: Path, printflag: bool=True):
Z=mpc[0].mean(axis=0), Z=mpc[0].mean(axis=0),
ERR=np.sqrt((mpc[1]**2).sum(axis=0)) / mpc.shape[1], ERR=np.sqrt((mpc[1]**2).sum(axis=0)) / mpc.shape[1],
cmap=plot_cmap, cmap=plot_cmap,
xlabel=pltlabels[1], xlabel=f"{pltlabels[1][0]} [{pltlabels[1][1]}]",
ylabel=pltlabels[0], ylabel=f"{pltlabels[0][0]} [{pltlabels[0][1]}]",
clabel="Mean photon count", clabel="Mean photon count",
title=fmt.format(*V), title=fmt.format(*V),
) )
...@@ -474,8 +471,8 @@ def process_file(filepath: Path, printflag: bool=True): ...@@ -474,8 +471,8 @@ def process_file(filepath: Path, printflag: bool=True):
Z=fat[0].mean(axis=0), Z=fat[0].mean(axis=0),
ERR=np.sqrt((fat[1]**2).sum(axis=0)) / fat.shape[1], ERR=np.sqrt((fat[1]**2).sum(axis=0)) / fat.shape[1],
cmap=plot_cmap, cmap=plot_cmap,
xlabel=pltlabels[1], xlabel=f"{pltlabels[1][0]} [{pltlabels[1][1]}]",
ylabel=pltlabels[0], ylabel=f"{pltlabels[0][0]} [{pltlabels[0][1]}]",
clabel="Fraction above threshold", clabel="Fraction above threshold",
title=fmt.format(*V), title=fmt.format(*V),
) )
...@@ -496,8 +493,8 @@ def process_file(filepath: Path, printflag: bool=True): ...@@ -496,8 +493,8 @@ def process_file(filepath: Path, printflag: bool=True):
ERR=np.sqrt((survival[1]**2).sum(axis=0)) ERR=np.sqrt((survival[1]**2).sum(axis=0))
/ survival.shape[1], / survival.shape[1],
cmap=plot_cmap, cmap=plot_cmap,
xlabel=pltlabels[1], xlabel=f"{pltlabels[1][0]} [{pltlabels[1][1]}]",
ylabel=pltlabels[0], ylabel=f"{pltlabels[0][0]} [{pltlabels[0][1]}]",
clabel="Survival fraction", clabel="Survival fraction",
title=fmt.format(*V), title=fmt.format(*V),
) )
...@@ -595,6 +592,7 @@ def process_file(filepath: Path, printflag: bool=True): ...@@ -595,6 +592,7 @@ def process_file(filepath: Path, printflag: bool=True):
ylabel="Prob. density", ylabel="Prob. density",
title=fmt.format(*V), title=fmt.format(*V),
) )
ax.set_xlim(hist_vmin, roi_totals[0].max() + 2 * hist_bin_size)
fig.savefig(outdir_hist.joinpath(pngname(*V))) fig.savefig(outdir_hist.joinpath(pngname(*V)))
pp.close(fig) pp.close(fig)
......
...@@ -137,7 +137,6 @@ class ROI: ...@@ -137,7 +137,6 @@ class ROI:
ret = np.zeros((*array.shape[:-2], 2, 2), dtype=np.uint32) ret = np.zeros((*array.shape[:-2], 2, 2), dtype=np.uint32)
ret[(*(N_ax * [S[:]]), 0, S[:])] = self.loc ret[(*(N_ax * [S[:]]), 0, S[:])] = self.loc
ret[(*(N_ax * [S[:]]), 1, S[:])] = self.size ret[(*(N_ax * [S[:]]), 1, S[:])] = self.size
print(ret.shape)
return ret return ret
else: else:
return self._optim_locs(sum_array, threshold, len(array.shape) - 2) return self._optim_locs(sum_array, threshold, len(array.shape) - 2)
......
...@@ -332,16 +332,17 @@ class DataSet: ...@@ -332,16 +332,17 @@ class DataSet:
# survival fraction (FAT conditioned on being above threshold in pre) # survival fraction (FAT conditioned on being above threshold in pre)
# [roi, params...] or None if not is_prepost # [roi, params...] or None if not is_prepost
if is_prepost: if is_prepost:
is_pre = roi_totals.T[0].T >= threshold is_pre = roi_totals.T[0].T[0] >= threshold
N_pre = is_pre.sum(axis=1) N_pre = is_pre.sum(axis=1)
with np.errstate(divide="ignore", invalid="ignore"): with np.errstate(divide="ignore", invalid="ignore"):
surv = ( surv = (
(is_pre * (roi_totals.T[1].T >= threshold)).sum(axis=1) (is_pre * (roi_totals.T[1].T[0] >= threshold)).sum(axis=1)
/ N_pre / N_pre
) )
print(surv.shape)
surv_err = np.sqrt((surv**2 + surv) / N_pre) surv_err = np.sqrt((surv**2 + surv) / N_pre)
surv[np.where(surv.isnan() + surv.isinf())] = -1.0 surv[np.where(np.isnan(surv) + np.isinf(surv))] = -0.05
surv_err[np.where(surv_err.isnan() + surv_err.isinf())] = -1.0 surv_err[np.where(np.isnan(surv_err) + np.isinf(surv_err))] = 0.0
survival = np.array([surv, surv_err]) survival = np.array([surv, surv_err])
else: else:
survival = None survival = None
...@@ -475,7 +476,7 @@ class DataSet: ...@@ -475,7 +476,7 @@ class DataSet:
F : numpy.ndarray[ndim=2+M, dtype=numpy.float64] F : numpy.ndarray[ndim=2+M, dtype=numpy.float64]
Fraction above threshold for each ROI. Axis 0 indexes Fraction above threshold for each ROI. Axis 0 indexes
quantity/error; axis 1 the ROI. quantity/error; axis 1 the ROI.
S : numpy.ndarray[ndim=2+N, dtype=numpy.float64] S : numpy.ndarray[ndim=2+M, dtype=numpy.float64]
Survival fraction for each ROI if the data has pre/post Survival fraction for each ROI if the data has pre/post
structure. Axis 0 indexes quantity/error; axis 1 the ROI. structure. Axis 0 indexes quantity/error; axis 1 the ROI.
If the data does not have pre/post structure, this value is If the data does not have pre/post structure, this value is
......
...@@ -50,7 +50,7 @@ def render_image( ...@@ -50,7 +50,7 @@ def render_image(
""" """
rois = list() if rois is None else rois rois = list() if rois is None else rois
roi_optim_locs = np.array([]) if roi_optim_locs is None else roi_optim_locs roi_optim_locs = np.array([]) if roi_optim_locs is None else roi_optim_locs
fig, ax = pp.subplots(figsize=[d / 5 for d in img.shape[::-1]]) \ fig, ax = pp.subplots(figsize=[20.0, 20.0]) \
if figax is None else figax if figax is None else figax
im = ax.imshow( im = ax.imshow(
img, img,
...@@ -301,12 +301,12 @@ def colorplot( ...@@ -301,12 +301,12 @@ def colorplot(
ERR : [val_i, val_j] ERR : [val_i, val_j]
""" """
ERR = np.zeros(Z.shape) if ERR is None else ERR ERR = np.zeros(Z.shape) if ERR is None else ERR
sort_idx_j = np.argsort(x) # sort_idx_j = np.argsort(x)
sort_idx_i = np.argsort(y) # sort_idx_i = np.argsort(y)
x = np.sort(x) # x = np.sort(x)
y = np.sort(y) # y = np.sort(y)
Z = S.sort_idx_nd(Z, tuple(), sort_idx_i, sort_idx_j) # Z = S.sort_idx_nd(Z, tuple(), sort_idx_i, sort_idx_j)
ERR = S.sort_idx_nd(ERR, tuple(), sort_idx_i, sort_idx_j) # ERR = S.sort_idx_nd(ERR, tuple(), sort_idx_i, sort_idx_j)
extent = gen_extent(x, y) extent = gen_extent(x, y)
fig, ax = pp.subplots() \ fig, ax = pp.subplots() \
...@@ -373,12 +373,12 @@ def multilineplot( ...@@ -373,12 +373,12 @@ def multilineplot(
""" """
ERR = np.zeros(Z.shape) if ERR is None else ERR ERR = np.zeros(Z.shape) if ERR is None else ERR
has_pop = POP is not None has_pop = POP is not None
sort_idx_j = np.argsort(x) # sort_idx_j = np.argsort(x)
sort_idx_i = np.argsort(y) # sort_idx_i = np.argsort(y)
x = np.sort(x) # x = np.sort(x)
y = np.sort(y) # y = np.sort(y)
Z = S.sort_idx_nd(Z, tuple(), sort_idx_i, sort_idx_j) # Z = S.sort_idx_nd(Z, tuple(), sort_idx_i, sort_idx_j)
ERR = S.sort_idx_nd(ERR, tuple(), sort_idx_i, sort_idx_j) # ERR = S.sort_idx_nd(ERR, tuple(), sort_idx_i, sort_idx_j)
if has_pop: if has_pop:
POP = S.sort_idx_nd( POP = S.sort_idx_nd(
POP, tuple(), np.arange(POP.shape[0]), sort_idx_i, sort_idx_j) POP, tuple(), np.arange(POP.shape[0]), sort_idx_i, sort_idx_j)
...@@ -413,6 +413,7 @@ def multilineplot( ...@@ -413,6 +413,7 @@ def multilineplot(
label=ylabel label=ylabel
) )
for k, yval in enumerate(y): for k, yval in enumerate(y):
color = cmap_f(k / (len(y) - 1)) if len(y) > 1 else "C0"
if has_pop: if has_pop:
v = ax.violinplot( v = ax.violinplot(
POP[:, k, :], x, POP[:, k, :], x,
...@@ -424,9 +425,15 @@ def multilineplot( ...@@ -424,9 +425,15 @@ def multilineplot(
vi.set_linewidth(0.0) vi.set_linewidth(0.0)
ax.errorbar( ax.errorbar(
x, Z[k, :], ERR[k, :], x, Z[k, :], ERR[k, :],
marker="o", linestyle="-", color=cmap_f(k / (len(y) - 1)), marker="o", linestyle="-", color=color,
label=f"{yval:.5f}", label=f"{yval:.5f}",
) )
if len(x) == 1:
ax.text(
x[0], Z[k, 0], f" ${Z[k, 0]:.5f} \\pm {ERR[k, 0]:.5f}$",
color="k", fontsize="x-small",
ha="left", va="center",
)
else: else:
ax.plot( ax.plot(
[], [], [], [],
...@@ -434,6 +441,7 @@ def multilineplot( ...@@ -434,6 +441,7 @@ def multilineplot(
label=xlabel label=xlabel
) )
for k, xval in enumerate(x): for k, xval in enumerate(x):
color = cmap_f(k / (len(x) - 1)) if len(x) > 1 else "C0"
if has_pop: if has_pop:
v = ax.violinplot( v = ax.violinplot(
POP[:, :, k], y, POP[:, :, k], y,
...@@ -445,9 +453,15 @@ def multilineplot( ...@@ -445,9 +453,15 @@ def multilineplot(
vi.set_linewidth(0.0) vi.set_linewidth(0.0)
ax.errorbar( ax.errorbar(
y, Z[:, k], ERR[:, k], y, Z[:, k], ERR[:, k],
merker="o", linestyle="-", color=cmap_f(k / (len(x) - 1)), marker="o", linestyle="-", color=color,
label=f"{xval:.5f}" label=f"{xval:.5f}"
) )
if len(y) == 1:
ax.text(
y[0], Z[0, k], f" ${Z[0, k]:.5f} \\pm {ERR[0, k]:.5f}$",
color="k", fontsize="x-small",
ha="left", va="center",
)
ax.minorticks_on() ax.minorticks_on()
ax.grid(True, "major") ax.grid(True, "major")
ax.grid(True, "minor", linestyle=":") ax.grid(True, "minor", linestyle=":")
...@@ -513,6 +527,12 @@ def lineplot( ...@@ -513,6 +527,12 @@ def lineplot(
x, yk, errk, x, yk, errk,
marker="o", linestyle="-", color=c marker="o", linestyle="-", color=c
) )
if len(x) == 1:
ax.text(
x[0], yk[0], f" ${yk[0]:.5f} \\pm {errk[0]:.5f}$",
color=c, fontsize="x-small",
ha="left", va="center",
)
if has_pop: if has_pop:
v = ax.violinplot( v = ax.violinplot(
pop.mean(axis=0), x, pop.mean(axis=0), x,
...@@ -522,10 +542,18 @@ def lineplot( ...@@ -522,10 +542,18 @@ def lineplot(
for vi in v["bodies"]: for vi in v["bodies"]:
vi.set_facecolor("k") vi.set_facecolor("k")
vi.set_linewidth(0.0) vi.set_linewidth(0.0)
m = y.mean(axis=0)
e = np.sqrt((err**2).sum(axis=0)) / err.shape[0]
ax.errorbar( ax.errorbar(
x, y.mean(axis=0), np.sqrt((err**2).sum(axis=0)) / err.shape[0], x, m, e,
marker="o", linestyle="-", color="k" marker="o", linestyle="-", color="k"
) )
if len(x) == 1:
ax.text(
x[0], m[0], f" ${m[0]:.5f} \\pm {e[0]:.5f}$",
color="k", fontsize="x-small",
ha="left", va="center",
)
ax.minorticks_on() ax.minorticks_on()
ax.grid(True, "major") ax.grid(True, "major")
ax.grid(True, "minor", linestyle=":") ax.grid(True, "minor", linestyle=":")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment