# cmp_obj_freq_acc.py
import os
import ujson
import matplotlib
matplotlib.use('pdf')
import matplotlib.pyplot as plt

def cmp_hist(path, bin_order):
    """Compute per-bin hit percentages from a hits-data JSON file.

    The file is loaded with ``ujson.load`` and must provide two mappings keyed
    by bin label: ``'hits_vs_rarity'`` (hits per bin) and ``'rarity_counts'``
    (items per bin).  The overall hit rate over the requested bins is printed
    followed by the total item count; it is printed as ``nan`` when every
    requested bin is empty.

    Args:
        path: Path of the JSON file to read.
        bin_order: Iterable of bin labels; results are returned in this order.

    Returns:
        A tuple ``(hist, counts, hits_list)`` of lists aligned with
        ``bin_order``: per-bin hit percentage, per-bin item count and per-bin
        hit count.

    Raises:
        KeyError: if either mapping or a requested bin label is missing.
    """
    with open(path, 'r') as fh:
        hits_data = ujson.load(fh)

    hits_by_bin = hits_data['hits_vs_rarity']
    counts_by_bin = hits_data['rarity_counts']

    hist = []
    counts = []
    hits_list = []
    total_hits = 0.0
    total_count = 0.0
    for bin_label in bin_order:
        hits = hits_by_bin[bin_label]
        bin_count = counts_by_bin[bin_label]
        total_hits += hits
        total_count += bin_count
        # The 1e-5 smoothing avoids dividing by zero for an empty bin; as a
        # consequence an empty bin reports 100%, not 0%.
        hist.append(100 * (hits + 1e-5) / (bin_count + 1e-5))
        counts.append(bin_count)
        hits_list.append(hits)

    # Every requested bin may be empty; report nan instead of raising
    # ZeroDivisionError.
    overall = total_hits / total_count if total_count else float('nan')
    # A single %-formatted argument prints the same text under both the
    # Python 2 print statement and the Python 3 print function.
    print('%s %s' % (overall, total_count))
    return hist, counts, hits_list
        

# Object hits-data JSON files to compare, keyed by run label.
paths_object_hits_data = {
    'obj_acc_with_ans': '/data/tanmay/GenVQA_Exp_Results/obj_atr_through_ans_single_feat/object_attribute_classifiers/object_scores/object_hits_data.json',
    'obj_acc_wo_ans': '/data/tanmay/GenVQA_Exp_Results/obj_atr_through_none_single_feat/object_attribute_classifiers/object_scores/object_hits_data.json',
}

# Alternative configuration: compare the attribute hits data instead.
# paths_object_hits_data = {
#     'obj_acc_with_ans': '/data/tanmay/GenVQA_Exp_Results/obj_atr_through_ans_single_feat/object_attribute_classifiers/attribute_scores/attribute_hits_data.json',
#     'obj_acc_wo_ans': '/data/tanmay/GenVQA_Exp_Results/obj_atr_through_none_single_feat/object_attribute_classifiers/attribute_scores/attribute_hits_data.json',
# }

# Bin labels looked up in each file's 'hits_vs_rarity' and 'rarity_counts'
# mappings, in the order results are reported.
bin_order = ['0','100','500','1000','5000','10000','many']

hist = dict()
# Sorted so runs are reported in a stable order (dict order is arbitrary on
# Python 2); each run is labeled so its output lines can be attributed.
for name, path in sorted(paths_object_hits_data.items()):
    print('%s:' % name)
    hist[name], counts, hits = cmp_hist(path, bin_order)
    # Single %-formatted argument: same output on Python 2 and Python 3.
    print('%s %s %s' % (hist[name], hits, counts))


# data=dict()
# for name, path in plots_to_compare.items():
#     with open(path,'r') as file:
#         data[name] = ujson.load(file)

# diff = []
# for acc1, acc2 in zip(data['obj_atr_through_none_det']['acc'],
#                       data['obj_atr_through_ans_det']['acc']):

#     diff.append(acc2-acc1)

# freq = data['obj_atr_through_ans_det']['freq']
# print len(freq), len(diff), len(data['obj_atr_through_ans_det']['freq'])
# plt.plot(freq,diff)

# plot_filename = os.path.join(
#     '/data/tanmay/GenVQA_Exp_Results/random_results',
#     'diff_acc_freq.pdf')

# plt.savefig(plot_filename)