diff --git a/object_attribute_classifier_cached_features/freq_acc_plot.py b/object_attribute_classifier_cached_features/freq_acc_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..0e4203af4f37a46cbd089f6c4babbc95f7fcf330 --- /dev/null +++ b/object_attribute_classifier_cached_features/freq_acc_plot.py @@ -0,0 +1,74 @@ +import os +import ujson +import pprint +from collections import namedtuple +from operator import attrgetter + +import constants + +pp = pprint.PrettyPrinter(indent=4) +if __name__=='__main__': + with open(constants.regions_json) as file: + all_regions_data = ujson.load(file) + + with open(constants.genome_train_region_ids) as file: + train_region_ids = ujson.load(file) + + obj_pred_json = os.path.join( + constants.region_object_scores_dirname, + 'object_predictions.json') + + with open(obj_pred_json) as file: + obj_pred_data = ujson.load(file) + + obj_train_freq_dict = dict() + obj_test_freq_dict = dict() + obj_hit_dict = dict() + + for region_id in train_region_ids: + region_data = all_regions_data[region_id] + for obj in region_data['object_names']: + if obj not in obj_train_freq_dict: + obj_train_freq_dict[obj] = 0 + + obj_train_freq_dict[obj] += 1 + + + + for region_id, pred_dict in obj_pred_data.items(): + for obj in pred_dict['gt']: + if obj not in obj_hit_dict: + obj_hit_dict[obj] = 0 + + if obj in pred_dict['pred']: + obj_hit_dict[obj] += 1 + + if obj not in obj_test_freq_dict: + obj_test_freq_dict[obj] = 1 + else: + obj_test_freq_dict[obj] += 1 + + + AccFreqData = namedtuple('AccFreqData',['label','acc','train_freq','test_freq']) + records = [] + for obj, hits in obj_hit_dict.items(): + if obj not in obj_train_freq_dict: + obj_train_freq_dict[obj] = 0 + + records += [AccFreqData( + label=obj, + acc=float(hits)/float(obj_test_freq_dict[obj]), + test_freq=obj_test_freq_dict[obj], + train_freq=obj_train_freq_dict[obj])] + + sorted_records = sorted(records, key=attrgetter('train_freq')) + pp.pprint(sorted_records) + # print '{} \t {} \t {}'.format( + # obj, + # hits/float(obj_test_freq_dict[obj]), + # obj_test_freq_dict[obj], + # ) + + print(len(obj_pred_data)) + +