import itertools

import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.io
import scipy.linalg
import scipy.signal as signal_power

from CalibPackage import *


# Monte Carlo SNR sweep for an N-point DFT computed with non-ideal matrix reads.
#
# For every bits-per-cell setting, device model, skew value and tested bit
# width, a noisy complex tone is transformed with DFT matrices read back via
# SWIPE, ConvRead ("No Compensation") and QuantRead. The error power against
# the noise-free DFT output is accumulated over NIterations trials, and the
# resulting SNR arrays (dB) are written to SkewData.mat.
#
# SWIPE, ConvRead, QuantRead, AWGN and getVarModel are not defined here;
# presumably they come from `from CalibPackage import *`.


def _mean_power(x):
    """Return mean(|x**2|), i.e. the mean squared magnitude of ``x``."""
    return np.mean(np.abs(np.power(x, 2)))


def _to_db(signal, noise):
    """Return ``10*log10(signal/noise)``.

    Cells that were never simulated hold 0/0 and become NaN (or inf); the
    resulting divide/invalid RuntimeWarnings are expected and so silenced.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        return 10 * np.log10(signal / noise)


# --- Signal and DFT set-up -------------------------------------------------
N = 128
points = np.arange(N) / N
time = points * 2 * np.pi
dft_mat = sp.linalg.dft(len(time))
freq = points / points[1]  # bin indices 0..N-1; not used below
im_dft_mat = np.imag(dft_mat)
re_dft_mat = np.real(dft_mat)
data = np.cos(30 * time) + 1j * np.sin(30 * time)  # exp(1j*30*time)

# --- Sweep parameters ------------------------------------------------------
skewRange = np.arange(-0.3, 0.3, 0.02)
BpCellRange = [1, 2, 3]
Optskew = [0, 0, 0]  # not used below
Maxbits = 12
# ModelNos = np.arange(10)  # full model sweep
ModelNos = [3, 5, 8]
# One row per model actually swept. (Sizing this from the 10-model list left
# rows that were never filled and came out as NaN in the saved SNR arrays.)
varRange = np.zeros([len(ModelNos)])
NIterations = 100  # noisy trials accumulated per (model, skew, bits) cell
SNR_NoComp = []
SNR_SWIPE = []
SNRQuant = []
BitRange = []
SNRin = 100  # passed to AWGN; presumably an SNR in dB -- confirm in CalibPackage
# Untouched copies of the DFT parts; every trial starts from fresh copies.
im_dft_mat1 = im_dft_mat * 1
re_dft_mat1 = re_dft_mat * 1
# Noise-free reference output. Trials re-copy these same matrices, so the
# reference is loop-invariant and is computed once here (from the same
# contiguous copies the trials use, so the values are identical).
idealOut = np.matmul(re_dft_mat1, data) + 1j * np.matmul(im_dft_mat1, data)
# Testbits = range(Maxbits * 10 + 1)  # every bit width
Testbits = [7]

# --- Sweep -----------------------------------------------------------------
for BpCell in BpCellRange:
    # NRange = int(Maxbits / BpCell)
    NRange = 10
    shape = [len(varRange), len(skewRange), NRange]
    Noise_SWIPE = np.zeros(shape)
    Noise_NoComp = np.zeros(shape)
    NoiseQuant = np.zeros(shape)
    SigPower = np.zeros(shape)
    for model_idx, ModelNo in enumerate(ModelNos):
        varRange[model_idx] = np.max(getVarModel(ModelNo, BpCell))
        Var = varRange[model_idx]
        for skew_idx, skew in enumerate(skewRange):
            for m in range(1, NRange + 1):
                Nbits = BpCell * m + 1
                if Nbits not in Testbits:
                    continue  # untested cells stay 0 and become NaN in dB
                cell = (model_idx, skew_idx, m - 1)
                for _ in range(NIterations):
                    nData = AWGN(data, SNRin)
                    # Fresh copies every trial, presumably in case the readers
                    # modify their input in place -- TODO confirm.
                    im_dft_mat = im_dft_mat1.copy()
                    re_dft_mat = re_dft_mat1.copy()

                    im_dft_SWIPE = SWIPE(im_dft_mat, Nbits, 0.7, skew=skew,
                                         BpCell=BpCell, ModelNo=ModelNo)
                    re_dft_SWIPE = SWIPE(re_dft_mat, Nbits, 0.7, skew=skew,
                                         BpCell=BpCell, ModelNo=ModelNo)
                    ReRAM_Output_SWIPE = (np.matmul(re_dft_SWIPE, nData)
                                          + 1j * np.matmul(im_dft_SWIPE, nData))
                    Noise_SWIPE[cell] += _mean_power(idealOut - ReRAM_Output_SWIPE)

                    im_dft_NoComp = ConvRead(im_dft_mat, Nbits, BpCell=BpCell,
                                             ModelNo=ModelNo)
                    re_dft_NoComp = ConvRead(re_dft_mat, Nbits, BpCell=BpCell,
                                             ModelNo=ModelNo)
                    ReRAM_Output_NoComp = (np.matmul(re_dft_NoComp, nData)
                                           + 1j * np.matmul(im_dft_NoComp, nData))
                    Noise_NoComp[cell] += _mean_power(idealOut - ReRAM_Output_NoComp)

                    im_dft_quant = QuantRead(im_dft_mat, Nbits)
                    re_dft_quant = QuantRead(re_dft_mat, Nbits)
                    QuantOut = (np.matmul(re_dft_quant, nData)
                                + 1j * np.matmul(im_dft_quant, nData))
                    NoiseQuant[cell] += _mean_power(idealOut - QuantOut)

                    SigPower[cell] += _mean_power(idealOut)
                print('Nbits={}'.format(Nbits))
                for label, noise in (('No Compensation', Noise_NoComp),
                                     ('Quantization', NoiseQuant),
                                     ('Compensation', Noise_SWIPE)):
                    print('SNR {} {}'.format(label, _to_db(SigPower[cell], noise[cell])))
            print('------------ skew={} -----------'.format(skew))
        print('------------ var={} -----------'.format(Var))
    print('############### BpCell={} ###############'.format(BpCell))
    SNR_NoComp.append(_to_db(SigPower, Noise_NoComp))
    SNR_SWIPE.append(_to_db(SigPower, Noise_SWIPE))
    SNRQuant.append(_to_db(SigPower, NoiseQuant))
    # Last-axis index m-1 for m = 1..NRange; that slot's bit width is BpCell*m + 1.
    BitRange.append(np.arange(1, NRange + 1))

sp.io.savemat('SkewData.mat', {'SNRConv': SNR_NoComp, 'SNRSWIPE': SNR_SWIPE,
                               'SNRQuant': SNRQuant, 'BitRange': BitRange})