from matplotlib import pyplot as plt
from collections import defaultdict
import os


# Collect the result files (*.txt) in the current working directory.
# os.listdir returns entries in arbitrary order, so sort to make the choice
# of db_name below deterministic.
files = sorted(
    f for f in os.listdir('.') if os.path.isfile(f) and f.endswith('.txt')
)
if not files:
    # Fail with a clear message instead of an IndexError on files[0].
    raise SystemExit("No .txt result files found in the current directory")
# The database name is the file-name prefix before the first '_'
# (assumes all result files share this prefix -- TODO confirm).
db_name = files[0].split('_')[0]

# Accumulators filled while parsing the result files.
# d[(vm, thread_count)] -> [summed throughput, summed thread_count]
d = defaultdict(lambda: [0, 0])
# Read-latency samples ([vm, latency]) grouped by total-thread bucket.
# The bucketing below can yield 500 (more than 100 total threads), so that key
# must exist too; without it such runs raised KeyError.
dr_latency = {10: [], 50: [], 100: [], 500: []}
# dr_latency_per_vm[(bucket, vm)] -> read latency summed over that config's files.
dr_latency_per_vm = defaultdict(lambda: 0)
du_latency = {10: [], 50: [], 100: []}  # not filled anywhere in this script
no_of_vms = 0  # largest VM count seen in any file name
for file in files:
    # Files whose name contains 'load' are skipped (presumably load-phase output).
    if 'load' in file:
        continue

    # File names are expected to end in "..._<thread_count>_<vm>.txt".
    name_parts = file.split('_')
    vm = int(name_parts[-1].split('.')[0])
    no_of_vms = max(no_of_vms, vm)
    thread_count = int(name_parts[-2])

    # Total threads for this file; constant per file, so bucket it once.
    x = thread_count * vm
    if x <= 10:
        key = 10
    elif x <= 50:
        key = 50
    elif x <= 100:
        key = 100
    else:
        key = 500

    # `with` guarantees the handle is released even if parsing raises.
    with open(file, 'r') as f:
        for line in f:
            if "Throughput" in line:
                throughput = float(line.split(',')[-1].strip())
                d[(vm, thread_count)][0] += throughput
                d[(vm, thread_count)][1] += thread_count
            # NOTE(review): this substring test also accepts any other line
            # whose operation name contains "READ" (e.g. a READ-MODIFY-WRITE
            # line, if the input has one) -- confirm that is intended.
            if "READ" in line and "AverageLatency" in line:
                latency = float(line.split(',')[-1].strip())
                dr_latency[key].append([vm, latency])
                dr_latency_per_vm[(key, vm)] += latency


# Turn summed read latency into an average per configuration.
# Keys are (bucket, vm); each sum is divided by its vm count (presumably one
# result file per VM, so this averages across VMs -- verify against inputs).
avg_latency_per_vm = {
    cfg: summed / cfg[1] for cfg, summed in dr_latency_per_vm.items()
}

# Reshape averages into {"Threads": {bucket: {"VMs": [...], "Latency": [...]}}},
# visiting configurations in sorted (bucket, vm) order.
final_dict2 = {"Threads": {}}
thread_dict2 = final_dict2["Threads"]
for (bucket, vm_count), mean_latency in sorted(avg_latency_per_vm.items()):
    group = thread_dict2.setdefault(bucket, {"VMs": [], "Latency": []})
    group["VMs"].append(vm_count)
    group["Latency"].append(mean_latency)
    
# Flatten the nested latency structure into parallel lists, one entry per
# bar, then draw a single categorical bar chart labelled "<vm> \n<bucket>".
data = final_dict2
bars = [
    (bucket, vm_count, lat)
    for bucket, group in data['Threads'].items()
    for vm_count, lat in zip(group['VMs'], group['Latency'])
]
threads = [bucket for bucket, _, _ in bars]
vms = [vm_count for _, vm_count, _ in bars]
latencies = [lat for _, _, lat in bars]

bar_labels = [f'{vm} \n{thread}' for vm, thread in zip(vms, threads)]
fig, ax = plt.subplots()
ax.bar(bar_labels, latencies)
ax.set_xlabel('Threads grouped by No of VMs')
ax.set_ylabel('Latency')
plt.savefig(db_name + '_Latency.png')


# Group throughput results by total-thread bucket (10 / 50 / 100).
# A configuration's bucket is the smallest key >= vm * thread_count -- the
# same thresholds the read-latency parsing loop uses. The previous version
# assigned buckets by sorted position (exactly no_of_vms entries per bucket),
# so one missing or extra run shifted later results into the wrong bucket.
# final_dict[bucket] -> [[vm, summed_throughput], ...] sorted by vm.
keys = [10, 50, 100]
final_dict = {bucket: [] for bucket in keys}
for (vm_count, per_vm_threads), (summed_throughput, _) in d.items():
    total_threads = vm_count * per_vm_threads
    bucket = next((k for k in keys if total_threads <= k), None)
    if bucket is None:
        # More than 100 total threads: not part of the three plotted buckets.
        continue
    final_dict[bucket].append([vm_count, summed_throughput])
for bucket in keys:
    final_dict[bucket].sort()
    
# Reshape final_dict into {"Threads": {bucket: {"VMs": [...], "Throughput": [...]}}},
# splitting each [vm, throughput] pair into two parallel lists.
final_dict2 = {"Threads": {}}
thread_dict2 = final_dict2["Threads"]
for bucket, pairs in final_dict.items():
    thread_dict2[bucket] = {
        "VMs": [pair[0] for pair in pairs],
        "Throughput": [pair[1] for pair in pairs],
    }



# Grouped bar chart: one group per thread bucket, one bar per VM count.
data = final_dict2
threads = sorted(data['Threads'].keys())
# Union of VM counts across all buckets (not just the first), so a VM count
# missing from one bucket is still plotted for the others.
vms = sorted({vm for t in threads for vm in data['Threads'][t]['VMs']})
# Look throughput up by VM count rather than by position ([vm-1]), which
# assumed every bucket holds exactly VMs 1..n in order. A missing
# (bucket, vm) combination becomes NaN so no bar is drawn for it.
by_vm = {
    t: dict(zip(data['Threads'][t]['VMs'], data['Threads'][t]['Throughput']))
    for t in threads
}
throughputs = {vm: [by_vm[t].get(vm, float('nan')) for t in threads] for vm in vms}

fig, ax = plt.subplots()
bar_width = 0.15
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

for i, vm in enumerate(vms):
    x = [j + i * bar_width for j in range(len(threads))]
    # Cycle the palette so more than five VM counts don't raise IndexError.
    ax.bar(x, throughputs[vm], bar_width, color=colors[i % len(colors)],
           label=f'{vm} VMs')

ax.set_xlabel('Number of threads')
ax.set_ylabel('Throughput')
# Centre each tick under its group of len(vms) bars (was hard-coded for 5).
group_centre = (len(vms) - 1) / 2 * bar_width
ax.set_xticks([i + group_centre for i in range(len(threads))])
ax.set_xticklabels(threads)
if vms:
    # Only add a legend when there are labelled bars to show.
    ax.legend()


plt.savefig(db_name+'_throughput.png')