# freq_acc_plot.py: per-object accuracy versus train/test frequency report.
import os
import pprint
from collections import Counter, namedtuple
from operator import attrgetter

import ujson

import constants
pp = pprint.PrettyPrinter(indent=4)

# One row of the report: an object label, its accuracy on the prediction
# set, and how often it occurs in training and in the prediction ground truth.
AccFreqData = namedtuple('AccFreqData', ['label', 'acc', 'train_freq', 'test_freq'])


def build_acc_freq_records(all_regions_data, train_region_ids, obj_pred_data):
    """Return per-object accuracy records sorted by training frequency.

    Args:
        all_regions_data: container indexed by region id; each entry must
            have an ``'object_names'`` iterable of object labels.
            NOTE(review): ids in ``train_region_ids`` must match its key type
            (JSON object keys are always strings) -- confirm against the data.
        train_region_ids: iterable of region ids whose objects count toward
            the training frequency.
        obj_pred_data: mapping region id -> dict with ``'gt'`` (ground-truth
            object labels) and ``'pred'`` (predicted labels, checked with
            ``in``).

    Returns:
        List of ``AccFreqData``, one per label that appears in any ``'gt'``
        list, sorted ascending by ``train_freq``. ``sorted`` is stable, so
        ties keep the labels' first-appearance order. ``acc`` is
        hits / test_freq, where a hit is a ``'gt'`` occurrence whose label is
        also in that region's ``'pred'``. Labels never seen in training get
        ``train_freq`` 0. The inputs are not modified.
    """
    train_freq = Counter(
        obj
        for region_id in train_region_ids
        for obj in all_regions_data[region_id]['object_names'])

    test_freq = Counter()
    hits = Counter()
    for pred_dict in obj_pred_data.values():
        predicted = pred_dict['pred']
        for obj in pred_dict['gt']:
            test_freq[obj] += 1
            if obj in predicted:
                hits[obj] += 1

    # Counter returns 0 for a missing key without inserting it. Unseen labels
    # therefore read as 0, and no counting dict is mutated as a side effect.
    # test_freq[obj] >= 1 for every key here, so the division is safe.
    records = [
        AccFreqData(
            label=obj,
            acc=float(hits[obj]) / count,
            train_freq=train_freq[obj],
            test_freq=count)
        for obj, count in test_freq.items()]
    return sorted(records, key=attrgetter('train_freq'))


def main():
    """Load regions, training ids and predictions, then print the report.

    Input paths come from the project ``constants`` module. The predictions
    file is ``object_predictions.json`` inside
    ``constants.region_object_scores_dirname``.
    """
    with open(constants.regions_json) as f:
        all_regions_data = ujson.load(f)
    with open(constants.genome_train_region_ids) as f:
        train_region_ids = ujson.load(f)
    obj_pred_json = os.path.join(
        constants.region_object_scores_dirname,
        'object_predictions.json')
    with open(obj_pred_json) as f:
        obj_pred_data = ujson.load(f)

    sorted_records = build_acc_freq_records(
        all_regions_data, train_region_ids, obj_pred_data)
    pp.pprint(sorted_records)
    # Number of entries (regions) in the predictions file.
    print(len(obj_pred_data))


if __name__ == '__main__':
    main()