import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import config
import geopandas as gpd
from shapely.geometry import LineString
import folium


# Weekday labels, indexed by the `weekday` argument (0-6) used by the plotting
# functions in this module; the first three letters of a label become part of
# the saved image file names.
# NOTE(review): the list starts on Saturday, presumably because day 0 of the
# data falls on a Saturday -- confirm against the data set.
days_of_week = ['Saturday', 'Sunday', 'Monday', 'Tuesday', 'Wednesday', 'Thursday','Friday']


def plot_particular_day(w, day, color = 'g'):
    '''Plot the 24 consecutive values of `w` that belong to day number `day`.

    `color` is handed to matplotlib as the format string of the line.
    '''
    first_index = 24 * day
    one_day = w[first_index : first_index + 24]
    plt.plot(one_day, color)

        
def find_mean_trend_particular_weekday(w, weekday):
    '''Return the mean 24-value daily profile over every seventh day starting at `weekday`.

    `w` is viewed as rows of 24 values (one row per day); rows `weekday`,
    `weekday + 7`, ... are averaged element-wise.
    '''
    n_days = len(w) // 24
    daily_rows = w.reshape(n_days, 24)
    same_weekday_rows = daily_rows[weekday::7]
    return np.mean(same_weekday_rows, axis=0)


def find_extreme_events_particular_sig(W, sig, weekday):
    '''Print which occurrences of `weekday` deviate unusually from the mean trend of one signature.

    Column `sig` of `W` is split into days of 24 values and every seventh day
    starting at `weekday` is compared with the mean profile of those days.
    A day is flagged when its total absolute deviation from the mean profile
    lies more than 1.5 standard deviations away from the average total
    deviation.

    Prints `sig` followed by the list of flagged positions (0 is the first
    occurrence of `weekday`, 1 the second, ...). Returns None.
    '''
    w = W[:, sig]
    mean_trend = find_mean_trend_particular_weekday(w, weekday)

    # Derive the number of days from the data, exactly as the mean-trend helper
    # does, instead of assuming the series holds precisely 365 days.
    same_weekday = w.reshape(len(w) // 24, 24)[weekday::7]

    deviation = np.sum(abs(same_weekday - mean_trend), axis=1)
    extreme_weeks = abs(deviation - np.mean(deviation)) > 1.5 * np.std(deviation)
    print(sig, [week for week, flagged in enumerate(extreme_weeks) if flagged])
    return None

def find_extreme_events_particular_weekday(W, weekday):
    '''Report the extreme occurrences of `weekday` for signatures 0 .. config.RANK - 1 in turn.'''
    signature_index = 0
    while signature_index < config.RANK:
        find_extreme_events_particular_sig(W, signature_index, weekday)
        signature_index += 1

def plot_particular_weekday(w, weekday, mean = False):
    '''Overlay the daily curves of every seventh day of `w`, starting at day `weekday`.

    Each day is drawn with `plot_particular_day`. When `mean` is true, the
    mean profile of those same days is added as a blue line.
    '''
    # Take the number of days from the data (24 values per day) so that the
    # plotted days are the same ones the mean trend is averaged over, rather
    # than assuming a 365-day series.
    n_days = len(w) // 24
    for day in range(weekday, n_days, 7):
        plot_particular_day(w, day)
    if mean:
        plt.plot(find_mean_trend_particular_weekday(w, weekday), 'b-')

        
def plot_all_weekdays(w, mean = False, signature = '??'):
    '''Save one image per weekday showing all daily curves of `w` for that weekday.

    Images are written to ./Images/Signature_Trends/ and named from
    `signature` (zero-padded to two characters) followed by the first three
    letters of the weekday, e.g. '03Sat.png'. When `mean` is true the mean
    profile is drawn on top.
    '''
    plt.ioff()
    fig = plt.figure(figsize = (16, 9))  # size in inches
    try:
        for weekday, day_name in enumerate(days_of_week):
            plot_particular_weekday(w, weekday, mean = mean)
            imagename = './Images/Signature_Trends/' + str(signature).zfill(2) + day_name[0:3] + '.png'
            plt.savefig(imagename, dpi=300, bbox_inches='tight')  # dpi ideal for viewing on laptop fullscreen
            plt.clf()
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # call would leave one more 16x9 figure in memory.
        plt.close(fig)

        
def plot_all_signatures_particular_weekday(W, weekday, mean = False):
    '''Save, for each of the config.RANK signatures, an image of its daily curves on `weekday`.

    Column `signature` of `W` is plotted for signature number `signature`.
    Images are written to ./Images/Signature_Trends/ and named from the
    two-digit signature number followed by the first three letters of the
    weekday. When `mean` is true the mean profile is drawn on top.
    '''
    plt.ioff()
    fig = plt.figure(figsize = (16, 9))  # size in inches
    day_prefix = days_of_week[weekday][0:3]
    try:
        for signature in range(config.RANK):
            plot_particular_weekday(W[:, signature], weekday, mean = mean)
            imagename = './Images/Signature_Trends/' + str(signature).zfill(2) + day_prefix + '.png'
            plt.savefig(imagename, dpi=300, bbox_inches='tight')  # dpi ideal for viewing on laptop fullscreen
            plt.clf()
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # call would leave one more 16x9 figure in memory.
        plt.close(fig)
        
        
def Heatmap(W1, W2):
    '''Draw, save and show a heatmap of the correlations between columns of W1 and columns of W2.

    Both matrices must have the same shape. The figure is saved under the
    name 'Heatmap' and then displayed. Returns None.
    '''
    assert W1.shape == W2.shape

    first_w2_row = len(W1.T)
    n_w1_columns = len(W2.T)
    stacked_corr = np.corrcoef(W1.T, W2.T)
    # Lower-left block of the stacked correlation matrix: one row per column
    # of W2, one column per column of W1.
    Ro = stacked_corr[first_w2_row:, :n_w1_columns]

    plt.imshow(Ro, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.title('Heatmap for comparing two runs of W')
    plt.xlabel('Columns of W1')
    plt.ylabel('Columns of W2')
    plt.savefig('Heatmap')
    plt.show()


# Requires full_link_ids.txt in the running folder, plus the links and H_trips
# files whose paths are returned by config.generate_filenames('')
# Requires the ./Images/Signature_Maps folder within the running folder
def map_signature(signature):
    '''Export the links with a positive weight in one signature as a geographic file.

    Row `signature` of the H matrix is read, the ids of the links whose entry
    is greater than zero are collected, and one line segment per matching row
    of the links table (startX/startY -> endX/endY) is written with
    GeoDataFrame.to_file to ./Images/Signature_Maps/sigNN. If the signature
    has no positive entries a message is printed instead. Returns None.
    '''
    filenames = config.generate_filenames('')

    full_link_ids = np.loadtxt('full_link_ids.txt', dtype=int)
    H_loaded = pd.read_csv(filenames['H_trips'], delimiter=' ', names=full_link_ids)
    link_df = pd.read_csv(filenames['links'])

    # NOTE(review): the {'init': ...} spelling is deprecated in recent
    # pyproj/geopandas releases in favour of 'EPSG:4326' -- confirm the
    # installed version before changing it.
    crs = {'init': 'epsg:4326'}

    weights = H_loaded.loc[signature, :]
    signature_list = weights[weights > 0].keys().tolist()

    if not signature_list:
        print('Signature ' + str(signature) + ' has no links')
        return None

    # Work on a copy so the geometry column is added to an independent frame
    # rather than to a filtered slice of link_df.
    df = link_df[link_df['link_id'].isin(signature_list)].copy()
    df['geom'] = [
        LineString([(start_x, start_y), (end_x, end_y)])
        for start_x, start_y, end_x, end_y
        in zip(df['startX'], df['startY'], df['endX'], df['endY'])
    ]
    geo_df = gpd.GeoDataFrame(df, crs=crs, geometry='geom')
    location = './Images/Signature_Maps/sig' + format(signature, '02d')
    geo_df.to_file(location)

    return None

def map_signature_folium(signature):
    '''Save an interactive folium map of the links with a positive weight in one signature.

    Row number `signature` of the H matrix (selected by position) gives the
    link ids whose entry is greater than zero; each matching row of the links
    table becomes one segment of a polyline layer. The map is saved to
    ./Images/Signature_Maps/sigNN.html. If the signature has no positive
    entries a message is printed and nothing is saved. Returns None.
    '''
    filenames = config.generate_filenames('')
    full_link_ids = np.loadtxt('full_link_ids.txt', dtype=int)
    H_loaded = pd.read_csv(filenames['H_trips'], delimiter=' ', names=full_link_ids)
    link_df = pd.read_csv(filenames['links'])

    weights = H_loaded.iloc[signature, :]
    signature_list = weights[weights > 0].keys().tolist()

    # Same guard as map_all_signatures_folium: there is nothing to draw for a
    # signature without links.
    if not signature_list:
        print('Signature ' + str(signature) + ' has no links')
        return None

    df = link_df[link_df['link_id'].isin(signature_list)]
    # folium expects (latitude, longitude) pairs, hence Y before X. Building
    # the list directly avoids writing a new column onto a slice of link_df.
    links = [
        [[start_y, start_x], [end_y, end_x]]
        for start_y, start_x, end_y, end_x
        in zip(df['startY'], df['startX'], df['endY'], df['endX'])
    ]

    m = folium.Map(width=700, location=[40.7529595, -73.9723327], zoom_start=12)
    folium.PolyLine(links).add_to(m)
    m.fit_bounds([(40.699136, -74.028113), (40.803235, -73.923239)])
    folium.TileLayer('cartodbpositron').add_to(m)
    location = './Images/Signature_Maps/sig' + format(signature, '02d')
    m.save(location + '.html')

    return None


def map_all_signatures_folium():
    '''Save one interactive folium map per signature that has at least one positive link weight.

    For each signature in range(config.RANK), the link ids with an entry
    greater than zero in that row of the H matrix are drawn as line segments
    and the map is saved to ./Images/Signature_Maps/sigNN.html. Signatures
    without positive entries are skipped silently. Returns None.
    '''
    filenames = config.generate_filenames('')
    full_link_ids = np.loadtxt('full_link_ids.txt', dtype=int)
    H_loaded = pd.read_csv(filenames['H_trips'], delimiter=' ', names=full_link_ids)
    link_df = pd.read_csv(filenames['links'])

    for signature in range(config.RANK):
        weights = H_loaded.loc[signature, :]
        signature_list = weights[weights > 0].keys().tolist()
        if not signature_list:
            continue

        df = link_df[link_df['link_id'].isin(signature_list)]
        # folium expects (latitude, longitude) pairs, hence Y before X.
        # Building the list directly avoids writing a new column onto a
        # slice of link_df.
        links = [
            [[start_y, start_x], [end_y, end_x]]
            for start_y, start_x, end_y, end_x
            in zip(df['startY'], df['startX'], df['endY'], df['endX'])
        ]

        m = folium.Map(width=700, location=[40.7529595, -73.9723327], zoom_start=12)
        folium.PolyLine(links).add_to(m)
        m.fit_bounds([(40.699136, -74.028113), (40.803235, -73.923239)])
        folium.TileLayer('cartodbpositron').add_to(m)
        location = './Images/Signature_Maps/sig' + format(signature, '02d')
        m.save(location + '.html')
    return None


def map_all_signatures():
    '''Export, for every signature, the links with a positive weight as a geographic file.

    For each signature in range(config.RANK), the link ids with an entry
    greater than zero in that row of the H matrix are turned into line
    segments (startX/startY -> endX/endY) and written with
    GeoDataFrame.to_file to ./Images/Signature_Maps/sigNN. A message is
    printed for signatures without positive entries. Returns None.
    '''
    filenames = config.generate_filenames('')

    full_link_ids = np.loadtxt('full_link_ids.txt', dtype=int)
    H_loaded = pd.read_csv(filenames['H_trips'], delimiter=' ', names=full_link_ids)
    # The links table is read once and reused for every signature.
    link_df = pd.read_csv(filenames['links'])

    # NOTE(review): the {'init': ...} spelling is deprecated in recent
    # pyproj/geopandas releases in favour of 'EPSG:4326' -- confirm the
    # installed version before changing it.
    crs = {'init': 'epsg:4326'}

    for signature in range(config.RANK):
        weights = H_loaded.loc[signature, :]
        signature_list = weights[weights > 0].keys().tolist()
        if not signature_list:
            print('Signature ' + str(signature) + ' has no links')
            continue

        # Work on a copy so the geometry column is added to an independent
        # frame rather than to a filtered slice of link_df.
        df = link_df[link_df['link_id'].isin(signature_list)].copy()
        df['geom'] = [
            LineString([(start_x, start_y), (end_x, end_y)])
            for start_x, start_y, end_x, end_y
            in zip(df['startX'], df['startY'], df['endX'], df['endY'])
        ]
        geo_df = gpd.GeoDataFrame(df, crs=crs, geometry='geom')
        location = './Images/Signature_Maps/sig' + format(signature, '02d')
        geo_df.to_file(location)

    return None


def map_links_with_many_signatures(threshold=7): # threshold for number of signatures for a link considered high
    '''Export the links that have a positive weight in at least `threshold` signatures.

    For every link (column of the H matrix) the number of rows with an entry
    greater than zero is counted; links whose count is >= `threshold` are
    turned into line segments and written with GeoDataFrame.to_file to
    ./Images/High_Signature_Map. If no row of the links table qualifies, a
    message is printed and nothing is written. Returns None.
    '''
    filenames = config.generate_filenames('')

    full_link_ids = np.loadtxt('full_link_ids.txt', dtype=int)
    H_loaded = pd.read_csv(filenames['H_trips'], delimiter=' ', names=full_link_ids)
    link_df = pd.read_csv(filenames['links'])

    # NOTE(review): the {'init': ...} spelling is deprecated in recent
    # pyproj/geopandas releases in favour of 'EPSG:4326' -- confirm the
    # installed version before changing it.
    crs = {'init': 'epsg:4326'}

    ds = np.sum(H_loaded > 0, axis=0)
    link_ids_with_high_signature_count = ds[ds >= threshold].keys().tolist()

    # Work on a copy so the geometry column is added to an independent frame
    # rather than to a filtered slice of link_df.
    df = link_df[link_df['link_id'].isin(link_ids_with_high_signature_count)].copy()
    if df.empty:
        print('No links appear in at least ' + str(threshold) + ' signatures')
        return None

    df['geom'] = [
        LineString([(start_x, start_y), (end_x, end_y)])
        for start_x, start_y, end_x, end_y
        in zip(df['startX'], df['startY'], df['endX'], df['endY'])
    ]
    geo_df = gpd.GeoDataFrame(df, crs=crs, geometry='geom')
    location = './Images/'
    geo_df.to_file(location + 'High_Signature_Map')

    return None
        

def map_links(link_list=None, link_file=None, out_file='Map'):
    '''Export a chosen set of links as a geographic file under ./Images/.

    Exactly one of `link_list` (an iterable of link ids) or `link_file` (a
    text file of link ids readable by numpy.loadtxt) must be given. Matching
    rows of the links table are turned into line segments
    (startX/startY -> endX/endY) and written with GeoDataFrame.to_file to
    ./Images/<out_file>. Returns None.
    '''
    # Reject "both" as before, and also "neither", which previously slipped
    # through the check and only failed later inside np.loadtxt(None).
    assert (link_list is None) != (link_file is None), 'Choose exactly one list or file to convert to map!'
    if link_list is not None:
        link_id_list = link_list
    else:
        link_id_list = np.loadtxt(link_file)

    filenames = config.generate_filenames('')
    link_df = pd.read_csv(filenames['links'])

    # NOTE(review): the {'init': ...} spelling is deprecated in recent
    # pyproj/geopandas releases in favour of 'EPSG:4326' -- confirm the
    # installed version before changing it.
    crs = {'init': 'epsg:4326'}

    # Work on a copy so the geometry column is added to an independent frame
    # rather than to a filtered slice of link_df.
    df = link_df[link_df['link_id'].isin(link_id_list)].copy()
    df['geom'] = [
        LineString([(start_x, start_y), (end_x, end_y)])
        for start_x, start_y, end_x, end_y
        in zip(df['startX'], df['startY'], df['endX'], df['endY'])
    ]
    geo_df = gpd.GeoDataFrame(df, crs=crs, geometry='geom')
    location = './Images/'
    geo_df.to_file(location + out_file)

    return None


def map_links_folium(link_list=None, link_file=None, out_file='Map'):
    '''Save an interactive folium map of a chosen set of links under ./Images/.

    Exactly one of `link_list` (an iterable of link ids) or `link_file` (a
    text file of link ids readable by numpy.loadtxt) must be given. Matching
    rows of the links table are drawn as line segments and the map is saved
    to ./Images/<out_file>.html ('Map.html' by default). Returns None.
    '''
    # Reject "both" as before, and also "neither", which previously slipped
    # through the check and only failed later inside np.loadtxt(None).
    assert (link_list is None) != (link_file is None), 'Choose exactly one list or file to convert to map!'
    if link_list is not None:
        link_id_list = link_list
    else:
        link_id_list = np.loadtxt(link_file)

    filenames = config.generate_filenames('')
    link_df = pd.read_csv(filenames['links'])

    df = link_df[link_df['link_id'].isin(link_id_list)]
    # folium expects (latitude, longitude) pairs, hence Y before X. Building
    # the list directly avoids writing a new column onto a slice of link_df.
    links = [
        [[start_y, start_x], [end_y, end_x]]
        for start_y, start_x, end_y, end_x
        in zip(df['startY'], df['startX'], df['endY'], df['endX'])
    ]

    m = folium.Map(width=700, location=[40.7529595, -73.9723327], zoom_start=12)
    folium.PolyLine(links).add_to(m)
    m.fit_bounds([(40.699136, -74.028113), (40.803235, -73.923239)])
    folium.TileLayer('cartodbpositron').add_to(m)
    location = './Images/'
    # Honour `out_file`; the name used to be hard-coded to 'Map.html', so the
    # parameter had no effect.
    m.save(location + out_file + '.html')

    return None


def correlation_check(W1, W2, threshold = 0.8):
    '''Correlate the columns of W1 with the columns of W2 and list the pairs above `threshold`.

    Returns `(Ro, possible_maps)`. `Ro[i, j]` is the correlation between
    column i of W2 and column j of W1. `possible_maps[j]` is the result of
    `np.where` (a one-element tuple holding an index array) naming the columns
    of W2 whose correlation with column j of W1 is strictly greater than
    `threshold`.
    '''
    assert W1.shape == W2.shape
    assert -1 <= threshold <= 1, "Unreasonable threshold. Must be between -1 and 1"

    stacked_corr = np.corrcoef(W1.T, W2.T)
    # Lower-left block: rows are columns of W2, columns are columns of W1.
    Ro = stacked_corr[len(W1.T):, :len(W2.T)]
    above_threshold = Ro > threshold

    print('Possible Permutation! ith column of W1 can be mapped to jth column of W2 appears as j is in possible_maps[i]')
    possible_maps = []
    for w1_column_flags in above_threshold.T:
        possible_maps.append(np.where(w1_column_flags > 0))
    return Ro, possible_maps



def find_permutation(W1, W2):
    '''Find permutation that gives a pretty-looking heatmap.

    Columns of W1 are greedily paired with columns of W2: the correlation
    threshold is lowered step by step from 0.9 towards 0, and at each step
    every still-unmatched column of W1 takes the first still-unmatched column
    of W2 that `correlation_check` lists for it. Columns left over at the end
    are paired in their remaining order.

    Returns `(permutation, sigs, SIGS)`: `permutation` is a list of
    `(W1 column, W2 column)` pairs covering range(config.RANK); `sigs` and
    `SIGS` are the columns of W1 and W2 that found no partner during the
    threshold sweep.
    '''
    sigs = list(range(config.RANK))
    SIGS = list(range(config.RANK))
    permutation = []

    for threshold in np.arange(0.9, -0.01, -0.01):
        _, possible_maps = correlation_check(W1, W2, threshold)

        # Loop over a snapshot: `sigs` shrinks as matches are made, and
        # removing from a list while iterating over it skips the next entry.
        for sig in list(sigs):
            for SIG in possible_maps[sig][0].tolist():
                if SIG in SIGS:
                    SIGS.remove(SIG)
                    sigs.remove(sig)
                    permutation.append((sig, SIG))
                    break

        if not sigs:
            # Everything is matched; lower thresholds cannot change the result.
            break

    # sigs and SIGS always have equal length because entries are removed in pairs.
    permutation.extend(zip(sigs, SIGS))
    return permutation, sigs, SIGS


def permute_and_sort(W1, H1, W2, permute = False):
    '''Draw a heatmap of W1 against W2, optionally reordering W1 and H1 first.

    With permute False the heatmap of W1 vs. W2 is drawn and
    `(W1, H1, W2, [])` is returned. With permute True a permutation is
    obtained from `find_permutation`, the columns of W1 and the matching
    entries of H1 are rearranged accordingly, the heatmap of the rearranged
    W1 vs. W2 is drawn, and `(w1, h1, W2, permutation)` is returned.
    '''
    def permute_WH(sigs, coeffs, permutation):
        # Column `source` of sigs and entry `source` of coeffs both move to
        # position `target`; the inputs themselves are left untouched.
        sigs2, coeffs2 = sigs.copy(), coeffs.copy()
        for pair in permutation:
            source, target = pair[0], pair[1]
            sigs2[:, target] = sigs[:, source]
            coeffs2[target] = coeffs[source]
        return sigs2, coeffs2

    if not permute:
        plt.figure(figsize = (10,9), dpi=200)
        Heatmap(W1, W2)
        return W1, H1, W2, []

    permutation, _unmatched_w1, _unmatched_w2 = find_permutation(W1, W2)
    w1, h1 = permute_WH(W1, H1, permutation)

    plt.figure(figsize = (10,10), dpi=200)
    Heatmap(w1, W2)
    return w1, h1, W2, permutation
    
    
def spikeyness_vs_error():
    '''Scatter-plot each link's deviation from weekly periodicity against its relative error.

    Reads D_trips.txt, W_trips.txt and H_trips.txt from the running folder.
    For every column (link) of D the x value is its "spikeyness" and the y
    value is the norm of that column of D - W@H divided by the norm of that
    column of D (NaNs treated as zero). The plot is saved under the name
    'Spikeyness_vs_Error'. Returns None.
    '''
    D = np.loadtxt('D_trips.txt')
    W = np.loadtxt('W_trips.txt')
    H = np.loadtxt('H_trips.txt')
    # Take the link count from the data instead of hard-coding 2302, so that
    # the x and y values always have matching lengths.
    n_links = D.shape[1]

    def spikeyness(link):
        # View the first 52 * 168 values as 52 rows of 24 * 7 values, take the
        # spread across rows at each position, average those spreads and
        # normalise by the link's overall mean (NaNs ignored).
        trend = D[:52*24*7, link].reshape(52, 24*7)
        return np.mean(np.nanstd(trend, axis=0)) / np.nanmean(D[:, link])

    def errors():
        return np.linalg.norm(np.nan_to_num(D - W@H), axis=0) / np.linalg.norm(np.nan_to_num(D), axis=0)

    plt.figure(figsize=(10,10))
    plt.plot([spikeyness(link) for link in range(n_links)], errors(), '+')
    plt.xlabel('Deviation from weekly periodicity')
    plt.ylabel('Error per link')
    plt.savefig('Spikeyness_vs_Error')
    return None