From 15aa9fc0e4b8e1e7871f63e9fa0a41f16f31e852 Mon Sep 17 00:00:00 2001
From: tgupta6 <tgupta6@illinois.edu>
Date: Thu, 3 Nov 2016 11:21:04 -0500
Subject: [PATCH] eval and visualize script for the interpreted evaluation

---
 .../eval_interpret.py                         |   4 +-
 visual_util/visualize_relevance_interpret.py  | 239 ++++++++++++++++++
 2 files changed, 241 insertions(+), 2 deletions(-)
 create mode 100644 visual_util/visualize_relevance_interpret.py

diff --git a/answer_classifier_cached_features/eval_interpret.py b/answer_classifier_cached_features/eval_interpret.py
index cf6b126..a418ec8 100644
--- a/answer_classifier_cached_features/eval_interpret.py
+++ b/answer_classifier_cached_features/eval_interpret.py
@@ -182,7 +182,7 @@ class eval_mgr():
             
             dict_entry['relevance_scores'] = eval_vars_dict['relevance_prob_' + str(j)].tolist()
 
-            selected_region = np.argmax(dict_entry['relevance_scores'][0,:])
+            selected_region = np.argmax(eval_vars_dict['relevance_prob_' + str(j)][0,:])
 
             question_id = batch['question_id'][j]
             pred_answer, pred_score = self.get_pred_answer(
@@ -197,7 +197,7 @@ class eval_mgr():
                 'pred_obj_labels': pred_obj_labels[selected_region + j*constants.num_region_proposals],
                 'pred_atr_labels': pred_atr_labels[selected_region + j*constants.num_region_proposals],
             }
-            
+
             if question_id not in self.seen_qids:
                 self.seen_qids.add(question_id)
                 self.results.append(result_entry)
diff --git a/visual_util/visualize_relevance_interpret.py b/visual_util/visualize_relevance_interpret.py
new file mode 100644
index 0000000..504269c
--- /dev/null
+++ b/visual_util/visualize_relevance_interpret.py
@@ -0,0 +1,239 @@
+import ujson
+import os
+import csv
+import numpy as np
+import random
+from matplotlib import cm
+
+import image_io
+from html_writer import HtmlWriter
+import pdb
+
+class RelevanceVisualizer():
+    def __init__(
+            self,
+            eval_data_json,
+            eval_results_json,
+            anno_data_json,
+            image_dir,
+            region_dir,
+            output_dir,
+            data_type):
+        self.image_dir = image_dir
+        self.region_dir = region_dir
+        self.output_dir = output_dir
+        self.data_type = data_type
+        self.eval_data = self.read_json_file(eval_data_json)
+        self.eval_results_ = self.read_json_file(eval_results_json)
+        self.eval_results = dict()
+        for item in self.eval_results_:
+            self.eval_results[item['question_id']] = item
+        self.anno_data = self.read_json_file(anno_data_json)
+        self.html_filename = os.path.join(output_dir, 'index.html')
+        self.html_writer = HtmlWriter(self.html_filename)
+
+    def read_json_file(self, filename):
+        print 'Reading {} ...'.format(filename)
+        with open(filename, 'r') as file:
+            return ujson.load(file)
+
+    def get_image_name(self,qid):
+        image_id = self.anno_data[qid]['image_id']
+        image_name = 'COCO_' + self.data_type + '_' + str(image_id).zfill(12)
+        image_path = os.path.join(image_dir, image_name + '.jpg')
+        return image_name, image_path
+
+    def get_image(self, key):
+        image_name, image_path = self.get_image_name(key)
+        im = image_io.imread(image_path)
+        return im
+
+    def get_bboxes(self, key):
+        image_name, image_path = self.get_image_name(key)
+        bbox_dir = os.path.join(self.region_dir, image_name)
+        bbox_csv = os.path.join(bbox_dir, 'edge_boxes.csv')
+        bboxes = []
+        with open(bbox_csv, 'r') as csvfile:
+            bbox_reader = csv.DictReader(
+                csvfile, 
+                delimiter=',', 
+                fieldnames=['x', 'y', 'w', 'h', 'score'])
+            for bbox in bbox_reader:
+                bboxes.append(bbox)
+        return bboxes
+        
+    def get_pos_ans_rel_scores(self, key):
+        ans, score = self.eval_data[key]['positive_answer'].items()[0]
+        return ans, float(score), self.eval_data[key]['relevance_scores'][0]
+        
+    def get_pred_ans_rel_scores(self, key):
+        ans, score = self.eval_data[key]['positive_answer'].items()[0]
+        score = float(score)
+        pred_ans_score = (ans, score)
+        pred_ans_id = 0
+        count = 1
+        for ans, score in self.eval_data[key]['negative_answers'].items():
+            score = float(score)
+            if score > pred_ans_score[1]:
+                pred_ans_score = (ans, score)
+                pred_ans_id = count
+            count += 1
+
+        return pred_ans_score[0], pred_ans_score[1], \
+            self.eval_data[key]['relevance_scores'][pred_ans_id]
+
+    def get_box_score_pairs(self, bboxes, scores):
+        pairs = []
+        for i, bbox in enumerate(bboxes):
+            pairs.append((bbox, scores[i]))
+        return pairs
+
+    def create_relevance_map(self, key, mode='pos'):
+        image_name, image_path = self.get_image_name(key)
+        im = image_io.imread(image_path)
+        if len(im.shape)==2:
+            im_h, im_w = im.shape
+            im_ = im
+            im = np.zeros([im_h, im_w, 3])
+            for i in xrange(3):
+                im[:,:,i] = im_
+
+        bboxes = self.get_bboxes(key)
+        if mode=='pos':
+            ans, ans_score, rel_scores = self.get_pos_ans_rel_scores(key)
+        elif mode=='pred':
+            ans, ans_score, rel_scores = self.get_pred_ans_rel_scores(key)
+        else:
+            print 'mode can only take values {\'pred\' or \'pos\'}'
+            raise
+        box_score_pairs = self.get_box_score_pairs(bboxes, rel_scores)
+        rel_map = np.zeros(im.shape[0:2])
+        for box, score in box_score_pairs:
+            box_map = self.make_boxmap(box, im.shape[0:2])
+            rel_map = rel_map + score*box_map
+            # gauss_map = self.make_gaussian(box, im.shape[0:2])
+            # rel_map = np.maximum(rel_map, score*box_map)
+        rel_map_ =cm.jet(np.uint8(rel_map*255))[:,:,0:3]*255
+        # im_rel_map = np.uint8(0.5*im+0.5*rel_map_)
+        im_rel_map = np.uint8(0.0*im + 1.0*np.tile(rel_map[:,:,None], [1,1,3])*im)
+        return rel_map, im_rel_map, im, ans, ans_score
+
+    def make_gaussian(self, box, im_size):
+        im_h, im_w = im_size
+        x = np.arange(im_w)
+        y = np.arange(im_h)
+        xx, yy = np.meshgrid(x, y)
+        sigma_x = float(box['w'])/4.0
+        sigma_y = float(box['h'])/4.0
+        cx = float(box['x']) + float(box['w'])/2.0
+        cy = float(box['y']) + float(box['w'])/2.0
+        g = np.exp(-((xx-cx)**2/(2*sigma_x**2)) - ((yy-cy)**2/(2*sigma_y**2)))
+        return g
+
+    def make_boxmap(self, box, im_size):
+        im_h, im_w = im_size
+        x = int(box['x'])
+        y = int(box['y'])
+        w = int(box['w'])
+        h = int(box['h'])
+        map = np.zeros(im_size)
+        map[y-1:y+h-1, x-1:x+w-1] = 1.0
+        return map
+
+    def write_html(self):
+        col_dict = {
+            0: 'Question',
+            1: 'Pos. Answer',
+            2: 'Pred. Answer',
+            3: 'Pred. Relevance',
+            4: 'Pred. Objects',
+            5: 'Pred. Attributes',
+            6: 'Question Id',
+            7: 'Image'
+        }
+        self.html_writer.add_element(col_dict)
+        random.seed(0)
+        qids = sorted(rel_vis.eval_data.keys())
+        random.shuffle(qids)
+        for qid in qids[0:5000]:
+            question = rel_vis.anno_data[qid]['question']
+
+            pred_rel, pred_im_rel, im, pred_ans, pred_score = rel_vis.create_relevance_map(
+                qid,
+                mode='pred')
+            pos_rel, pos_im_rel, im, pos_ans, pos_score = rel_vis.create_relevance_map(
+                qid,
+                mode='pos')
+
+            if np.max(pred_rel) < 0.5:
+                continue
+
+            pred_im_name = 'pred_rel_' + qid + '.jpg'
+            pos_im_name = 'pos_rel_' + qid + '.jpg'
+            im_name = 'im_' + qid + '.jpg'
+
+            pred_rel_filename = os.path.join(self.output_dir, pred_im_name)
+            pos_rel_filename = os.path.join(self.output_dir, pos_im_name)
+            im_filename = os.path.join(self.output_dir, im_name)
+
+            image_io.imwrite(pred_im_rel, pred_rel_filename)
+            image_io.imwrite(pos_im_rel, pos_rel_filename)
+            image_io.imwrite(im, im_filename)
+
+            im_h, im_w = pred_rel.shape[0:2]
+            col_dict = {
+                0 : question,
+                1 : pos_ans + ': ' + str(pos_score),      
+                2 : pred_ans + ': ' + str(pred_score),               
+                3 : self.html_writer.image_tag(pred_im_name, im_h, im_w),
+                4 : self.eval_results[int(qid)]['pred_obj_labels'], 
+                5 : self.eval_results[int(qid)]['pred_atr_labels'], 
+                6 : qid,
+                7 : self.html_writer.image_tag(im_name, im_h, im_w),
+            }
+            
+            self.html_writer.add_element(col_dict)
+
+        self.html_writer.close_file()
+
+
+if __name__=='__main__':
+    data_type = 'val2014'
+    data_dir = '/home/ssd/VQA'
+
+    image_dir = os.path.join(data_dir, data_type)
+
+    region_dir = os.path.join(
+        data_dir,
+        data_type + '_cropped_large')
+
+    anno_data_json = os.path.join(
+        data_dir, 
+        'mscoco_val2014_annotations_with_parsed_questions.json')
+
+    exp_dir = '/home/tanmay/Code/GenVQA/Exp_Results/models_cvpr/' + \
+              'ans_through_obj_atr_rel_bin_feats/'
+
+    eval_data_json = os.path.join(
+        exp_dir,
+        'answer_classifiers/Results/eval_val_data.json')
+
+    eval_results_json = os.path.join(
+        exp_dir,
+        'answer_classifiers/Results/eval_val_results.json')
+
+    output_dir = os.path.join(exp_dir, 'qual_results_val_interpret')
+    if not os.path.exists(output_dir):
+        os.mkdir(output_dir)
+
+    rel_vis = RelevanceVisualizer(
+        eval_data_json,
+        eval_results_json,
+        anno_data_json,
+        image_dir,
+        region_dir,
+        output_dir,
+        data_type)
+
+    rel_vis.write_html()
+
-- 
GitLab