"""Relate object-classifier accuracy to how often each label occurs in training.

Originally added as
``object_attribute_classifier_cached_features/ratio_acc_plot_obj.py``.

Run as a script, this module:

1. loads the object label vocabulary, the per-label frequencies in the VQA
   train subset, the region annotations, the training region ids and the
   cached object predictions;
2. counts, per label, its occurrences in the training regions, its
   occurrences among the ground-truth labels of the predictions, and how
   many of those ground-truth occurrences were also predicted ("hits");
3. pretty-prints one record per label sorted by
   ``log(train_freq / freq_in_vqa)``;
4. aggregates hits and ground-truth counts into training-frequency buckets,
   prints the per-bucket accuracy (in percent) and writes both aggregates to
   ``object_hits_data.json`` in ``constants.region_object_scores_dirname``.
"""
import os
import pdb
import ujson
import pprint
import math
from collections import namedtuple
from operator import attrgetter
import numpy as np
import matplotlib
matplotlib.use('pdf')
import matplotlib.pyplot as plt

import constants

pp = pprint.PrettyPrinter(indent=4)

# Additive smoothing applied to both sides of every ratio so that labels with
# a zero count neither divide by zero nor produce log(0).
_EPSILON = 1e-5

# (exclusive lower bound on train frequency, bucket name), checked from the
# largest bound down. Each bucket is named after its upper bound; labels
# never seen in training fall through to _ZERO_BUCKET.
_RARITY_THRESHOLDS = (
    (10000, 'many'),
    (5000, '10000'),
    (1000, '5000'),
    (500, '1000'),
    (100, '500'),
    (0, '100'),
)
_ZERO_BUCKET = '0'
_BUCKET_NAMES = ('0', '100', '500', '1000', '5000', '10000', 'many')

AccFreqData = namedtuple(
    'AccFreqData',
    ['label', 'acc', 'hits', 'train_freq', 'test_freq', 'ratio'])


def _load_json(path):
    """Announce and parse the JSON file at ``path``; return its contents."""
    # Single-argument print(...) parses identically under Python 2 and 3.
    print('Reading {}'.format(path))
    with open(path) as json_file:
        return ujson.load(json_file)


def _rarity_bucket(train_freq):
    """Return the name of the frequency bucket that ``train_freq`` falls in."""
    for lower_bound, bucket in _RARITY_THRESHOLDS:
        if train_freq > lower_bound:
            return bucket
    return _ZERO_BUCKET


def main():
    """Compute per-label and per-bucket accuracy statistics and save them."""
    unknown = constants.unknown_token

    object_labels_dict = _load_json(constants.object_labels_json)
    freq_in_vqa = _load_json(os.path.join(
        '/home/ssd/VisualGenome/restructured/',
        'object_label_freq_in_vqa_train_subset.json'))
    # The original message named constants.eval_regions_json while opening
    # constants.regions_json; the message now reports the file really read.
    all_regions_data = _load_json(constants.regions_json)
    train_region_ids = _load_json(constants.genome_train_region_ids)
    obj_pred_data = _load_json(os.path.join(
        constants.region_object_scores_dirname,
        'object_predictions.json'))

    obj_train_freq_dict = dict.fromkeys(object_labels_dict, 0)
    obj_test_freq_dict = dict.fromkeys(object_labels_dict, 0)
    obj_hit_dict = dict.fromkeys(object_labels_dict, 0)
    # Out-of-vocabulary names are tallied under the unknown token, so that
    # bucket must exist even if the label file does not list the token.
    for counts in (obj_train_freq_dict, obj_test_freq_dict, obj_hit_dict):
        counts.setdefault(unknown, 0)

    # Training frequency of every label over the training regions.
    for region_id in train_region_ids:
        for obj in all_regions_data[region_id]['object_names']:
            key = obj if obj in obj_train_freq_dict else unknown
            obj_train_freq_dict[key] += 1

    # Ground-truth frequency and hit count over the cached predictions.
    for pred_dict in obj_pred_data.values():
        predicted = pred_dict['pred']
        for obj in pred_dict['gt']:
            key = obj if obj in obj_test_freq_dict else unknown
            if obj in predicted:
                obj_hit_dict[key] += 1
            obj_test_freq_dict[key] += 1

    records = []
    for obj, hits in obj_hit_dict.items():
        if obj == unknown:
            continue
        # Bind the counts to locals first: the original referenced a bare
        # `train_freq` inside the constructor call, which raised NameError.
        train_freq = obj_train_freq_dict[obj]
        test_freq = obj_test_freq_dict[obj]
        records.append(AccFreqData(
            label=obj,
            acc=float(hits + _EPSILON) / float(test_freq + _EPSILON),
            hits=hits,
            train_freq=train_freq,
            test_freq=test_freq,
            # Smoothed on both sides so train_freq == 0 does not give -inf.
            ratio=np.log(
                float(train_freq + _EPSILON) /
                (freq_in_vqa[obj] + _EPSILON))))

    sorted_records = sorted(records, key=attrgetter('ratio'))
    pp.pprint(sorted_records)
    # A leftover pdb.set_trace() used to stop the script here; it has been
    # removed so the script can run unattended.

    hits_vs_rarity = dict.fromkeys(_BUCKET_NAMES, 0)
    rarity_counts = dict.fromkeys(_BUCKET_NAMES, 0)
    for record in sorted_records:
        bucket = _rarity_bucket(record.train_freq)
        hits_vs_rarity[bucket] += record.hits
        rarity_counts[bucket] += record.test_freq

    # Per-bucket accuracy in percent.
    for key, value in hits_vs_rarity.items():
        print('{}: {}'.format(
            key,
            100 * float(value + _EPSILON) / (rarity_counts[key] + _EPSILON)))

    print(hits_vs_rarity)

    hits_data_filename = os.path.join(
        constants.region_object_scores_dirname,
        'object_hits_data.json')
    hits_data = {
        'hits_vs_rarity': hits_vs_rarity,
        'rarity_counts': rarity_counts,
    }
    with open(hits_data_filename, 'w') as json_file:
        ujson.dump(hits_data, json_file, indent=4)

    print(len(obj_pred_data))


if __name__ == '__main__':
    main()