# freq_acc_plot.py: compare per-object prediction accuracy on the test
# predictions against how often each object appears in the training regions.
import os
import ujson
import pprint
import math
from collections import namedtuple
from operator import attrgetter
import numpy as np
import matplotlib
matplotlib.use('pdf')
import matplotlib.pyplot as plt

import constants

pp = pprint.PrettyPrinter(indent=4)

# One row of the accuracy-vs-frequency table, built per ground-truth object
# label found in the prediction file.
AccFreqData = namedtuple('AccFreqData', ['label', 'acc', 'train_freq', 'test_freq'])

# Window of the train-frequency-sorted records that gets plotted.  These used
# to be hard-coded as freq[200:800] / acc[200:800]; kept as the defaults.
PLOT_START = 200
PLOT_END = 800


def _load_json(filename):
    """Print a progress message and return the parsed contents of a JSON file."""
    print('Reading {}'.format(filename))
    with open(filename) as f:
        return ujson.load(f)


def count_train_freq(all_regions_data, train_region_ids):
    """Count how often each object name occurs across the training regions.

    Args:
        all_regions_data: mapping region id -> region dict holding an
            'object_names' list.
        train_region_ids: iterable of ids used to index all_regions_data.
            NOTE(review): JSON object keys are always strings, so these ids
            presumably are strings too -- confirm against the data files.

    Returns:
        dict mapping object name -> number of occurrences (repeats within a
        region are each counted).
    """
    train_freq = {}
    for region_id in train_region_ids:
        for obj in all_regions_data[region_id]['object_names']:
            train_freq[obj] = train_freq.get(obj, 0) + 1
    return train_freq


def count_test_hits(obj_pred_data):
    """Count test occurrences and correct predictions per ground-truth object.

    Args:
        obj_pred_data: mapping region id -> dict with a 'gt' collection of
            ground-truth object names and a 'pred' collection of predicted
            object names.

    Returns:
        (hits, test_freq): two dicts keyed by object name with identical key
        sets.  test_freq[obj] counts every appearance of obj in a 'gt'
        collection; hits[obj] counts the subset of those appearances where
        obj is also in the same region's 'pred'.
    """
    hits = {}
    test_freq = {}
    for pred_dict in obj_pred_data.values():
        predicted = pred_dict['pred']
        for obj in pred_dict['gt']:
            test_freq[obj] = test_freq.get(obj, 0) + 1
            hits[obj] = hits.get(obj, 0) + (1 if obj in predicted else 0)
    return hits, test_freq


def build_records(hits, test_freq, train_freq):
    """Build AccFreqData rows sorted by ascending training frequency.

    acc is hits / test occurrences.  Objects missing from train_freq get a
    train_freq of 0.  Ties in train_freq are broken by label, so the output
    order does not depend on dict iteration order (which differs between
    Python versions).
    """
    records = [
        AccFreqData(
            label=obj,
            acc=float(n_hits) / float(test_freq[obj]),
            train_freq=train_freq.get(obj, 0),
            test_freq=test_freq[obj])
        for obj, n_hits in hits.items()]
    return sorted(records, key=attrgetter('train_freq', 'label'))


def main(plot_start=PLOT_START, plot_end=PLOT_END):
    """Compute per-object accuracy vs. training frequency; save JSON and a PDF plot.

    Writes 'acc_by_freq.json' (parallel 'freq' / 'acc' lists in ascending
    train-frequency order) and 'obj_acc_vs_freq.pdf' (records
    [plot_start:plot_end] only) into constants.region_object_scores_dirname.
    """
    all_regions_data = _load_json(constants.regions_json)
    train_region_ids = _load_json(constants.genome_train_region_ids)
    obj_pred_data = _load_json(os.path.join(
        constants.region_object_scores_dirname,
        'object_predictions.json'))

    train_freq = count_train_freq(all_regions_data, train_region_ids)
    hits, test_freq = count_test_hits(obj_pred_data)
    sorted_records = build_records(hits, test_freq, train_freq)
    pp.pprint(sorted_records)

    freq = [record.train_freq for record in sorted_records]
    acc = [record.acc for record in sorted_records]
    print(len(freq))

    acc_by_freq_filename = os.path.join(
        constants.region_object_scores_dirname,
        'acc_by_freq.json')
    with open(acc_by_freq_filename, 'w') as f:
        ujson.dump({'freq': freq, 'acc': acc}, f)

    # Python slicing clamps out-of-range bounds, so with fewer than
    # plot_start records this draws an empty plot rather than failing.
    plt.plot(freq[plot_start:plot_end], acc[plot_start:plot_end])
    plt.xlabel('train frequency')
    plt.ylabel('accuracy (hits / test occurrences)')
    plot_filename = os.path.join(
        constants.region_object_scores_dirname,
        'obj_acc_vs_freq.pdf')
    plt.savefig(plot_filename)

    print(len(obj_pred_data))


if __name__ == '__main__':
    main()