import scipy.signal as scisig
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import os
from mpl_toolkits.axes_grid1 import make_axes_locatable


# STFT segment length in samples. Long segments give fine frequency
# resolution at the cost of time resolution.
NPERSEG = 256 * 100
# Absolute |STFT| level below which bins are zeroed before plotting.
# NOTE(review): this is in the signal's own amplitude units, so it depends on
# how the waveform was scaled — confirm 0.01 still fits when the data changes.
MAG_THRESHOLD = 0.01

# Use the NpzFile as a context manager so its underlying zip handle is closed
# once the arrays have been read into memory (np.load keeps it open otherwise).
# SECURITY: allow_pickle=True is required because 'wfm' is stored as a pickled
# object array; only load .npz files that come from a trusted source.
with np.load("data/waveform_data.npz", allow_pickle=True) as file:
    signal = file['signal']
    # 'wfm' is a 0-d object array; .item() unwraps the stored Python object,
    # which is expected to expose a `sample_rate` attribute.
    wfm = file['wfm'].item()
sr = wfm.sample_rate

f, t, Sxx = scisig.stft(signal, fs=sr, nperseg=NPERSEG)
f /= 1e6  # Hz -> MHz (assumes sample_rate is in Hz, as the plot labels imply)
t *= 1e3  # s -> ms
# Zero low-magnitude bins. Besides cleaning up the magnitude plot, this makes
# np.angle() return 0 for those bins, so sub-threshold noise shows up as a
# constant phase instead of random values in the phase plot.
Sxx[np.abs(Sxx) < MAG_THRESHOLD] = 0

# Frequency window (MHz) shown on both spectrogram plots.
FREQ_LIMITS_MHZ = (95, 115)
# Set to False to build the figures without writing the (large, dpi=1200) PNGs.
SAVE_FIGURES = True


def _plot_spectrogram(times, freqs, values, title, filename, *,
                      colorbar_label=None, save=SAVE_FIGURES, **mesh_kwargs):
    """Draw *values* as a time/frequency colour mesh and optionally save it.

    Parameters
    ----------
    times, freqs : array_like
        Axis coordinates in ms and MHz. Their lengths must match the columns
        and rows of *values*, as returned by ``scipy.signal.stft``.
    values : 2-D array_like
        Real-valued data to plot, indexed ``[freq, time]``.
    title : str
        Axes title.
    filename : str
        Output path used by ``savefig`` when *save* is true.
    colorbar_label : str, optional
        Label for the colour bar; omitted when None.
    save : bool
        Whether to write the figure to *filename* (at dpi=1200).
    **mesh_kwargs
        Forwarded to ``Axes.pcolormesh`` (e.g. ``shading``, ``cmap``).

    Returns
    -------
    fig, ax, im
        The created figure, axes and mesh, for further customisation.
    """
    fig, ax = plt.subplots()
    im = ax.pcolormesh(times, freqs, values, **mesh_kwargs)
    ax.set_title(title)
    ax.set_ylabel('Frequency [MHz]')
    ax.set_xlabel('Time [ms]')
    ax.set_ylim(*FREQ_LIMITS_MHZ)
    # Give the colour bar its own appended axes so it matches the plot height.
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    cbar = fig.colorbar(im, cax=cax)
    if colorbar_label is not None:
        cbar.set_label(colorbar_label)
    if save:
        fig.savefig(filename, dpi=1200)
    return fig, ax, im


# Magnitude plot. The filename is kept as-is for existing consumers, even
# though the figure shows |STFT| magnitude rather than "frequency".
fig, ax, im = _plot_spectrogram(
    t, f, np.abs(Sxx),
    "Signal Spectrogram Magnitude",
    "data/Spectrogram-frequency.png",
    colorbar_label='|STFT|',
    shading='gouraud',
)

# Phase plot. Phase wraps at ±π, so Gouraud interpolation between adjacent
# +π and -π cells would invent intermediate values that do not exist.
# 'nearest' draws each bin as-is, and the cyclic 'twilight' colormap gives
# +π and -π the same colour.
fig, ax, im = _plot_spectrogram(
    t, f, np.angle(Sxx),
    "Signal Spectrogram Phase",
    "data/Spectrogram-phase.png",
    colorbar_label='Phase [rad]',
    shading='nearest',
    cmap='twilight',
)