From f88c7cd9c8ed86b8c41712d722aecb7b33456628 Mon Sep 17 00:00:00 2001
From: tgupta6 <tgupta6@illinois.edu>
Date: Sat, 8 Oct 2016 14:20:52 -0500
Subject: [PATCH] add_hypernyms fun in visual_genome_parser

---
 visual_genome_parser.py | 40 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 38 insertions(+), 2 deletions(-)

diff --git a/visual_genome_parser.py b/visual_genome_parser.py
index 1f30a3e..ed6bcde 100644
--- a/visual_genome_parser.py
+++ b/visual_genome_parser.py
@@ -29,6 +29,8 @@ _raw_attribute_labels = 'raw_attribute_labels.json'
 _object_labels = 'object_labels.json'
 _attribute_labels = 'attribute_labels.json'
 _regions_with_labels = 'region_with_labels.json'
+_regions_with_hypernym_labels = 'region_with_hypernym_labels.json'
+_object_labels_to_hypernyms = 'object_labels_to_hypernyms.json'
 _unknown_token = 'UNK'
 _unopenable_images = 'unopenable_images.json'
 _vocab = 'vocab.json'
@@ -528,6 +530,37 @@ def create_genome_to_vqa_map():
     return map


+def add_hypernyms(
+        in_regions_json,
+        out_regions_json,
+        labels_to_hypernyms_json):
+
+    in_regions_json_filename = os.path.join(
+        _outdir, in_regions_json)
+    print 'Reading file {} ...'.format(in_regions_json_filename)
+    with open(in_regions_json_filename,'r') as file:
+        regions_data = ujson.load(file)
+
+    labels_to_hypernyms_json_filename = os.path.join(
+        _outdir, labels_to_hypernyms_json)
+    print 'Reading file {} ...'.format(labels_to_hypernyms_json_filename)
+    with open(labels_to_hypernyms_json_filename,'r') as file:
+        hypernyms = ujson.load(file)
+    print len(hypernyms)
+
+    for region_id, region_info in regions_data.items():
+        object_names = set(region_info['object_names'])
+        for label in region_info['object_names']:
+            if label in hypernyms:
+                object_names.update(hypernyms[label])
+        regions_data[region_id]['object_names'] = list(object_names)
+
+    out_regions_json_filename = os.path.join(
+        _outdir, out_regions_json)
+    print 'Writing file {} ...'.format(out_regions_json_filename)
+    with open(out_regions_json_filename,'w') as file:
+        ujson.dump(regions_data,file,sort_keys=True,indent=4)
+

 def partition_region_ids(
         vqa_train_subset_qids_json,
@@ -668,7 +701,10 @@ if __name__=='__main__':
     #     _vqa_train_anno,
     #     _vqa_val_anno,
     #     _genome_train_subset_region_ids)
-
-    generate_md5hash()
+    add_hypernyms(
+        _regions_with_labels,
+        _regions_with_hypernym_labels,
+        _object_labels_to_hypernyms)
+    # generate_md5hash()


-- 
GitLab