Skip to content
Snippets Groups Projects
freq_acc_plot.py 2.08 KiB
import os
import pprint
from collections import Counter, namedtuple
from operator import attrgetter

import ujson

import constants

pp = pprint.PrettyPrinter(indent=4)

# One row of the per-object report: the object label, its accuracy on the
# prediction set, and how often it occurs in the training and test ground truth.
AccFreqData = namedtuple('AccFreqData', ['label', 'acc', 'train_freq', 'test_freq'])


def load_json(path):
    """Open *path* and return its contents parsed with ujson."""
    with open(path) as f:
        return ujson.load(f)


def count_train_frequencies(all_regions_data, train_region_ids):
    """Count how often each object name occurs across the training regions.

    Args:
        all_regions_data: container indexed by region id; each entry must
            have an ``'object_names'`` iterable.
        train_region_ids: ids of the regions to count.

    Returns:
        Counter mapping object name -> number of occurrences. A name listed
        several times in one region is counted once per occurrence.
    """
    # NOTE(review): ids are used verbatim as indices into all_regions_data.
    # If that is a JSON object its keys are strings, so the ids presumably
    # are strings too -- confirm against the files named in `constants`.
    return Counter(
        obj
        for region_id in train_region_ids
        for obj in all_regions_data[region_id]['object_names'])


def count_test_hits(obj_pred_data):
    """Tally ground-truth occurrences and correct predictions per object.

    Args:
        obj_pred_data: mapping region id -> dict with a ``'gt'`` iterable of
            ground-truth object names and a ``'pred'`` container of
            predicted object names.

    Returns:
        ``(hit_counts, test_counts)``, two Counters keyed by object name.
        ``test_counts[obj]`` is the number of ground-truth occurrences of
        ``obj`` (keys in order of first appearance). ``hit_counts[obj]`` is
        how many of those occurrences had ``obj`` among that region's
        predictions; objects never hit are absent and read as 0.
    """
    hit_counts = Counter()
    test_counts = Counter()
    for pred_dict in obj_pred_data.values():
        predicted = pred_dict['pred']
        for obj in pred_dict['gt']:
            test_counts[obj] += 1
            if obj in predicted:
                hit_counts[obj] += 1
    return hit_counts, test_counts


def build_records(train_counts, hit_counts, test_counts):
    """Build one AccFreqData per tested object, sorted by training frequency.

    Accuracy is ``hits / test occurrences``. Objects absent from
    *train_counts* get ``train_freq=0``. The sort is stable, so objects with
    equal training frequency keep their order in *test_counts*. None of the
    input mappings is modified. Every value in *test_counts* must be
    non-zero.
    """
    records = [
        AccFreqData(
            label=obj,
            # float() keeps true division even under Python 2.
            acc=float(hit_counts.get(obj, 0)) / test_freq,
            train_freq=train_counts.get(obj, 0),
            test_freq=test_freq)
        for obj, test_freq in test_counts.items()]
    return sorted(records, key=attrgetter('train_freq'))


def main():
    """Pretty-print per-object accuracy vs. train/test frequency.

    Afterwards, print the number of regions that have predictions.
    """
    all_regions_data = load_json(constants.regions_json)
    train_region_ids = load_json(constants.genome_train_region_ids)
    obj_pred_data = load_json(os.path.join(
        constants.region_object_scores_dirname,
        'object_predictions.json'))

    train_counts = count_train_frequencies(all_regions_data, train_region_ids)
    hit_counts, test_counts = count_test_hits(obj_pred_data)

    pp.pprint(build_records(train_counts, hit_counts, test_counts))
    print(len(obj_pred_data))


if __name__ == '__main__':
    main()