From f88c7cd9c8ed86b8c41712d722aecb7b33456628 Mon Sep 17 00:00:00 2001
From: tgupta6 <tgupta6@illinois.edu>
Date: Sat, 8 Oct 2016 14:20:52 -0500
Subject: [PATCH] add_hypernyms fun in visual_genome_parser

---
 visual_genome_parser.py | 40 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 38 insertions(+), 2 deletions(-)

diff --git a/visual_genome_parser.py b/visual_genome_parser.py
index 1f30a3e..ed6bcde 100644
--- a/visual_genome_parser.py
+++ b/visual_genome_parser.py
@@ -29,6 +29,8 @@ _raw_attribute_labels = 'raw_attribute_labels.json'
 _object_labels = 'object_labels.json'
 _attribute_labels = 'attribute_labels.json'
 _regions_with_labels = 'region_with_labels.json'
+_regions_with_hypernym_labels = 'region_with_hypernym_labels.json'
+_object_labels_to_hypernyms = 'object_labels_to_hypernyms.json'
 _unknown_token = 'UNK'
 _unopenable_images = 'unopenable_images.json'
 _vocab = 'vocab.json'
@@ -528,6 +530,37 @@ def create_genome_to_vqa_map():
         
     return map
 
+def add_hypernyms(
+        in_regions_json,
+        out_regions_json,
+        labels_to_hypernyms_json):
+
+    in_regions_json_filename = os.path.join(
+        _outdir, in_regions_json)
+    print 'Reading file {} ...'.format(in_regions_json_filename)
+    with open(in_regions_json_filename,'r') as file:
+        regions_data = ujson.load(file)
+
+    labels_to_hypernyms_json_filename = os.path.join(
+        _outdir, labels_to_hypernyms_json)
+    print 'Reading file {} ...'.format(labels_to_hypernyms_json_filename)
+    with open(labels_to_hypernyms_json_filename,'r') as file:
+        hypernyms = ujson.load(file)
+    print len(hypernyms)
+
+    for region_id, region_info in regions_data.items():
+        object_names = set(region_info['object_names'])
+        for label in region_info['object_names']:
+            if label in hypernyms:
+                object_names.update(hypernyms[label])
+        regions_data[region_id]['object_names'] = list(object_names)
+
+    out_regions_json_filename = os.path.join(
+        _outdir, out_regions_json)
+    print 'Writing file {} ...'.format(out_regions_json_filename)
+    with open(out_regions_json_filename,'w') as file:
+        ujson.dump(regions_data,file,sort_keys=True,indent=4)
+        
 
 def partition_region_ids(
         vqa_train_subset_qids_json,
@@ -668,7 +701,10 @@ if __name__=='__main__':
     #     _vqa_train_anno,
     #     _vqa_val_anno,
     #     _genome_train_subset_region_ids)
-
-    generate_md5hash()
+    add_hypernyms(
+        _regions_with_labels,
+        _regions_with_hypernym_labels,
+        _object_labels_to_hypernyms)
+    # generate_md5hash()
 
     
-- 
GitLab