"""Visualize region relevance for interpreted VQA evaluation results.

Recovered from git patch 15aa9fc0e4b8e1e7871f63e9fa0a41f16f31e852
(tgupta6 <tgupta6@illinois.edu>, Thu, 3 Nov 2016 11:21:04 -0500,
"[PATCH] eval and visualize script for the interpreted evaluation"), in
which this module is added as visual_util/visualize_relevance_interpret.py.

The same patch also edits ``eval_mgr`` in
answer_classifier_cached_features/eval_interpret.py::

    - selected_region = np.argmax(dict_entry['relevance_scores'][0,:])
    + selected_region = np.argmax(eval_vars_dict['relevance_prob_' + str(j)][0,:])

(the old line applied ``[0,:]`` to the plain list produced by ``.tolist()``)
and makes a whitespace-only change to the blank line that precedes
``if question_id not in self.seen_qids:``.

The code below is written to run unchanged on Python 2.7 and Python 3.
"""

import csv
import os
import pdb  # unused here; kept because the original file imports it
import random

import numpy as np
import ujson
from matplotlib import cm  # unused here; kept because the original imports it

import image_io
from html_writer import HtmlWriter


class RelevanceVisualizer(object):
    """Build per-question relevance maps and an HTML index page for them.

    The three JSON inputs are used as follows (schemas are inferred from the
    accesses in this class only -- confirm against the code that writes them):

    * ``eval_data[qid]`` has ``'positive_answer'`` and ``'negative_answers'``
      dicts mapping answer -> score, and ``'relevance_scores'``, a sequence
      of per-answer lists of per-box scores.
    * ``eval_results`` is a list of dicts carrying ``'question_id'``,
      ``'pred_obj_labels'`` and ``'pred_atr_labels'``.
    * ``anno_data[qid]`` has ``'image_id'`` and ``'question'``.
    """

    def __init__(
            self,
            eval_data_json,
            eval_results_json,
            anno_data_json,
            image_dir,
            region_dir,
            output_dir,
            data_type):
        """Load all JSON inputs and open ``<output_dir>/index.html``.

        Args:
            eval_data_json: Path to the per-question evaluation data JSON.
            eval_results_json: Path to the evaluation results JSON (a list).
            anno_data_json: Path to the annotation JSON.
            image_dir: Directory holding the ``COCO_<data_type>_*.jpg`` files.
            region_dir: Directory with one sub-directory per image name, each
                containing an ``edge_boxes.csv``.
            output_dir: Directory that receives the images and ``index.html``.
            data_type: Dataset split string used in image names
                (e.g. ``'val2014'``).
        """
        self.image_dir = image_dir
        self.region_dir = region_dir
        self.output_dir = output_dir
        self.data_type = data_type
        self.eval_data = self.read_json_file(eval_data_json)
        self.eval_results_ = self.read_json_file(eval_results_json)
        # Re-key the result list by question id for O(1) lookup in write_html.
        self.eval_results = {
            item['question_id']: item for item in self.eval_results_}
        self.anno_data = self.read_json_file(anno_data_json)
        self.html_filename = os.path.join(output_dir, 'index.html')
        self.html_writer = HtmlWriter(self.html_filename)

    def read_json_file(self, filename):
        """Return the parsed contents of the JSON file at ``filename``."""
        print('Reading {} ...'.format(filename))
        with open(filename, 'r') as json_file:
            return ujson.load(json_file)

    def get_image_name(self, qid):
        """Return ``(image_name, image_path)`` for the image of ``qid``.

        ``image_name`` has no extension; ``image_path`` is the ``.jpg`` file
        inside ``self.image_dir``.
        """
        image_id = self.anno_data[qid]['image_id']
        image_name = 'COCO_' + self.data_type + '_' + str(image_id).zfill(12)
        # Fixed: the original joined against a module-level ``image_dir``
        # that only exists when this file is executed as a script.
        image_path = os.path.join(self.image_dir, image_name + '.jpg')
        return image_name, image_path

    def get_image(self, key):
        """Read and return the image that belongs to question ``key``."""
        _, image_path = self.get_image_name(key)
        return image_io.imread(image_path)

    def get_bboxes(self, key):
        """Return the region proposals for the image of question ``key``.

        Each proposal is one row of ``<region_dir>/<image_name>/edge_boxes.csv``
        as a dict with the string-valued keys ``x``, ``y``, ``w``, ``h`` and
        ``score``.
        """
        image_name, _ = self.get_image_name(key)
        bbox_csv = os.path.join(self.region_dir, image_name, 'edge_boxes.csv')
        with open(bbox_csv, 'r') as csvfile:
            bbox_reader = csv.DictReader(
                csvfile,
                delimiter=',',
                fieldnames=['x', 'y', 'w', 'h', 'score'])
            return list(bbox_reader)

    def get_pos_ans_rel_scores(self, key):
        """Return ``(answer, score, relevance_scores)`` of the positive answer.

        The positive answer's relevance scores are taken from index 0 of
        ``relevance_scores``.
        """
        entry = self.eval_data[key]
        ans, score = list(entry['positive_answer'].items())[0]
        return ans, float(score), entry['relevance_scores'][0]

    def get_pred_ans_rel_scores(self, key):
        """Return ``(answer, score, relevance_scores)`` of the top answer.

        The positive answer (index 0) competes with every negative answer;
        the i-th negative answer in dict iteration order uses index ``i``
        (1-based) of ``relevance_scores``. Ties keep the earlier candidate.

        NOTE(review): this assumes the iteration order of
        ``negative_answers`` matches the order in which ``relevance_scores``
        was written -- confirm against the evaluation code.
        """
        entry = self.eval_data[key]
        best_ans, best_score = list(entry['positive_answer'].items())[0]
        best_score = float(best_score)
        best_id = 0
        negatives = entry['negative_answers'].items()
        for idx, (ans, score) in enumerate(negatives, 1):
            score = float(score)
            if score > best_score:
                best_ans, best_score, best_id = ans, score, idx
        return best_ans, best_score, entry['relevance_scores'][best_id]

    def get_box_score_pairs(self, bboxes, scores):
        """Pair every box with the score at the same position.

        Raises:
            IndexError: If ``scores`` has fewer entries than ``bboxes``.
        """
        return [(bbox, scores[i]) for i, bbox in enumerate(bboxes)]

    def create_relevance_map(self, key, mode='pos'):
        """Compute the relevance map of question ``key`` and apply it.

        Args:
            key: Question id, as used to index ``eval_data``/``anno_data``.
            mode: ``'pos'`` to use the positive answer's relevance scores,
                ``'pred'`` to use those of the highest scoring answer.

        Returns:
            ``(rel_map, im_rel_map, im, ans, ans_score)`` where ``rel_map`` is
            the H x W sum of ``score * box_mask`` over all boxes,
            ``im_rel_map`` is the uint8 image scaled per pixel by ``rel_map``,
            ``im`` is the (3-channel) image, and ``ans``/``ans_score``
            describe the answer selected by ``mode``.

        Raises:
            ValueError: If ``mode`` is neither ``'pos'`` nor ``'pred'``.
        """
        # Validate first so a bad mode fails before any file is read. The
        # original used a bare ``raise`` here, which has no active exception
        # to re-raise and therefore failed with an unrelated error.
        if mode == 'pos':
            get_scores = self.get_pos_ans_rel_scores
        elif mode == 'pred':
            get_scores = self.get_pred_ans_rel_scores
        else:
            raise ValueError(
                'mode can only take values {\'pred\' or \'pos\'}')

        _, image_path = self.get_image_name(key)
        im = image_io.imread(image_path)
        if len(im.shape) == 2:
            # Grayscale: replicate into 3 channels. float64 matches the
            # np.zeros buffer the original implementation filled.
            im = np.dstack([im] * 3).astype(np.float64)

        bboxes = self.get_bboxes(key)
        ans, ans_score, rel_scores = get_scores(key)

        map_size = im.shape[0:2]
        rel_map = np.zeros(map_size)
        for box, score in self.get_box_score_pairs(bboxes, rel_scores):
            rel_map += score * self.make_boxmap(box, map_size)

        # Broadcasting over the channel axis replaces the explicit np.tile.
        # Clipping only matters if the summed relevance exceeds 1, where the
        # uint8 cast would otherwise wrap around.
        im_rel_map = np.uint8(np.clip(rel_map[:, :, None] * im, 0, 255))
        return rel_map, im_rel_map, im, ans, ans_score

    def make_gaussian(self, box, im_size):
        """Return an H x W Gaussian bump centred on ``box``.

        The standard deviations are a quarter of the box width and height.
        Not called within this module; ``create_relevance_map`` uses
        ``make_boxmap``.

        Args:
            box: Mapping with ``x``, ``y``, ``w``, ``h`` (numbers or numeric
                strings).
            im_size: ``(height, width)`` of the output map.
        """
        im_h, im_w = im_size
        xx, yy = np.meshgrid(np.arange(im_w), np.arange(im_h))
        box_w = float(box['w'])
        box_h = float(box['h'])
        sigma_x = box_w / 4.0
        sigma_y = box_h / 4.0
        cx = float(box['x']) + box_w / 2.0
        # Fixed: the vertical centre used the box *width* in the original.
        cy = float(box['y']) + box_h / 2.0
        return np.exp(
            -((xx - cx) ** 2 / (2 * sigma_x ** 2))
            - ((yy - cy) ** 2 / (2 * sigma_y ** 2)))

    def make_boxmap(self, box, im_size):
        """Return an ``im_size`` map that is 1.0 inside ``box`` and 0.0 outside.

        Coordinates are shifted by -1 before indexing, i.e. ``x``/``y`` are
        treated as 1-based (presumably MATLAB-style edge boxes -- confirm).

        Args:
            box: Mapping with ``x``, ``y``, ``w``, ``h`` (numbers or numeric
                strings).
            im_size: ``(height, width)`` of the output map.
        """
        # int(float(...)) also accepts values such as '12.0' from the CSV.
        x = int(float(box['x']))
        y = int(float(box['y']))
        w = int(float(box['w']))
        h = int(float(box['h']))
        box_map = np.zeros(im_size)
        # Clamp at 0: a negative slice bound would index from the far edge
        # and silently drop boxes that start at coordinate 0.
        row_start, row_stop = max(y - 1, 0), max(y + h - 1, 0)
        col_start, col_stop = max(x - 1, 0), max(x + w - 1, 0)
        box_map[row_start:row_stop, col_start:col_stop] = 1.0
        return box_map

    def write_html(self, max_questions=5000, min_peak_relevance=0.5, seed=0):
        """Write the relevance images and the HTML table to ``output_dir``.

        Question ids are sorted, shuffled with ``seed`` and truncated to
        ``max_questions``. Questions whose predicted-answer relevance map
        peaks below ``min_peak_relevance`` are skipped. The defaults reproduce
        the values that were previously hard-coded.

        NOTE(review): the positive-answer overlay (``pos_rel_<qid>.jpg``) is
        written to disk but is not referenced by any table column.
        """
        self.html_writer.add_element({
            0: 'Question',
            1: 'Pos. Answer',
            2: 'Pred. Answer',
            3: 'Pred. Relevance',
            4: 'Pred. Objects',
            5: 'Pred. Attributes',
            6: 'Question Id',
            7: 'Image',
        })

        # Fixed: the original reached for a module-level ``rel_vis`` object
        # instead of ``self`` throughout this method.
        random.seed(seed)
        qids = sorted(self.eval_data.keys())
        random.shuffle(qids)

        try:
            for qid in qids[0:max_questions]:
                pred_rel, pred_im_rel, im, pred_ans, pred_score = \
                    self.create_relevance_map(qid, mode='pred')
                # Skip before building the 'pos' map to avoid wasted work.
                if np.max(pred_rel) < min_peak_relevance:
                    continue
                _, pos_im_rel, im, pos_ans, pos_score = \
                    self.create_relevance_map(qid, mode='pos')

                pred_im_name = 'pred_rel_' + qid + '.jpg'
                pos_im_name = 'pos_rel_' + qid + '.jpg'
                im_name = 'im_' + qid + '.jpg'

                image_io.imwrite(
                    pred_im_rel, os.path.join(self.output_dir, pred_im_name))
                image_io.imwrite(
                    pos_im_rel, os.path.join(self.output_dir, pos_im_name))
                image_io.imwrite(im, os.path.join(self.output_dir, im_name))

                im_h, im_w = pred_rel.shape[0:2]
                # eval_results is keyed by the integer question id, while
                # qid is used as a string everywhere else in this method.
                result = self.eval_results[int(qid)]
                self.html_writer.add_element({
                    0: self.anno_data[qid]['question'],
                    1: pos_ans + ': ' + str(pos_score),
                    2: pred_ans + ': ' + str(pred_score),
                    3: self.html_writer.image_tag(pred_im_name, im_h, im_w),
                    4: result['pred_obj_labels'],
                    5: result['pred_atr_labels'],
                    6: qid,
                    7: self.html_writer.image_tag(im_name, im_h, im_w),
                })
        finally:
            # Close the page even if a question fails part-way through.
            self.html_writer.close_file()


def main():
    """Generate the qualitative-results page for the val2014 split."""
    data_type = 'val2014'
    data_dir = '/home/ssd/VQA'

    image_dir = os.path.join(data_dir, data_type)
    region_dir = os.path.join(data_dir, data_type + '_cropped_large')
    anno_data_json = os.path.join(
        data_dir,
        'mscoco_val2014_annotations_with_parsed_questions.json')

    exp_dir = '/home/tanmay/Code/GenVQA/Exp_Results/models_cvpr/' + \
        'ans_through_obj_atr_rel_bin_feats/'
    eval_data_json = os.path.join(
        exp_dir,
        'answer_classifiers/Results/eval_val_data.json')
    eval_results_json = os.path.join(
        exp_dir,
        'answer_classifiers/Results/eval_val_results.json')

    output_dir = os.path.join(exp_dir, 'qual_results_val_interpret')
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    rel_vis = RelevanceVisualizer(
        eval_data_json,
        eval_results_json,
        anno_data_json,
        image_dir,
        region_dir,
        output_dir,
        data_type)
    rel_vis.write_html()


if __name__ == '__main__':
    main()