Skip to content
Snippets Groups Projects
Commit 22de2cfa authored by tgupta6's avatar tgupta6
Browse files

eval_interpret

parent 60ed8b9c
No related branches found
No related tags found
No related merge requests found
# from word2vec.word_vector_management import word_vector_manager
# import object_attribute_classifier.inference as feature_graph
# import region_relevance_network.inference as relevance_graph
# import answer_classifier.inference as answer_graph
from tftools import var_collect, placeholder_management
import tftools.data
import losses
import constants
import tftools.var_collect as var_collect
import data.vqa_cached_features as vqa_data
import answer_classifier_cached_features.train as train
import numpy as np
import pdb
import ujson
import tensorflow as tf
def create_initializer(graph, sess, model):
class initializer():
def __init__(self):
with graph.tf_graph.as_default():
model_vars = graph.vars_to_save
model_restorer = tf.train.Saver(model_vars)
model_restorer.restore(sess, model)
not_to_init = model_vars
all_vars = tf.all_variables()
other_vars = [var for var in all_vars
if var not in not_to_init]
var_collect.print_var_list(
other_vars,
'vars_to_init')
self.init = tf.initialize_variables(other_vars)
def initialize(self):
sess.run(self.init)
return initializer()
def create_batch_generator(mode):
if mode=='val':
vqa_resnet_feat_dir = constants.vqa_val_resnet_feat_dir
vqa_anno = constants.vqa_val_anno
qids_json = constants.vqa_val_qids
else:
print "mode needs to be one of {'val'}, found " + mode
data_mgr = vqa_data.data(
vqa_resnet_feat_dir,
vqa_anno,
qids_json,
constants.vocab_json,
constants.vqa_answer_vocab_json,
constants.object_labels_json,
constants.attribute_labels_json,
constants.image_size,
constants.num_region_proposals,
constants.num_negative_answers,
resnet_feat_dim=constants.resnet_feat_dim)
num_questions = len(data_mgr.qids)
index_generator = tftools.data.sequential(
constants.answer_batch_size,
num_questions,
1,
0)
batch_generator = tftools.data.async_batch_generator(
data_mgr,
index_generator,
constants.answer_queue_size)
return batch_generator
def create_feed_dict_creator(plh, num_neg_answers):
def feed_dict_creator(batch):
vqa_batch = batch
batch_size = len(vqa_batch['question'])
# Create vqa inputs
inputs = {
'region_feats': np.concatenate(vqa_batch['region_feats'], axis=0),
'positive_answer': vqa_batch['positive_answer'],
}
for i in xrange(4):
bin_name = 'bin_' + str(i)
inputs[bin_name] = [
vqa_batch['question'][j][bin_name] for j in xrange(batch_size)]
for i in xrange(num_neg_answers):
answer_name = 'negative_answer_' + str(i)
inputs[answer_name] = [
vqa_batch['negative_answers'][j][i] for j in xrange(batch_size)]
inputs['positive_nouns'] = [
a + b for a, b in zip(
vqa_batch['question_nouns'],
vqa_batch['positive_answer_nouns'])]
inputs['positive_adjectives'] = [
a + b for a, b in zip(
vqa_batch['question_adjectives'],
vqa_batch['positive_answer_adjectives'])]
inputs['positive_nouns_identity'] = [
vqa_batch['nouns_identity'][j][0] for j in xrange(batch_size)]
inputs['positive_adjectives_identity'] = [
vqa_batch['adjectives_identity'][j][0] for j in xrange(batch_size)]
for i in xrange(num_neg_answers):
name = 'negative_nouns_' + str(i)
list_ith_negative_answer_nouns = [
vqa_batch['negative_answers_nouns'][j][i]
for j in xrange(batch_size)]
inputs[name] = [
a + b for a, b in zip(
vqa_batch['question_nouns'],
list_ith_negative_answer_nouns)]
name = 'negative_adjectives_' + str(i)
list_ith_negative_answer_adjectives = [
vqa_batch['negative_answers_adjectives'][j][i]
for j in xrange(batch_size)]
inputs[name] = [
a + b for a, b in zip(
vqa_batch['question_adjectives'],
list_ith_negative_answer_adjectives)]
name = 'negative_nouns_identity_' + str(i)
inputs[name] = [
vqa_batch['nouns_identity'][j][i+1] for j in xrange(batch_size)]
name = 'negative_adjectives_identity_' + str(i)
inputs[name] = [
vqa_batch['adjectives_identity'][j][i+1] for j in xrange(batch_size)]
inputs['yes_no_num_feat'] = vqa_batch['yes_no_num_feat']
inputs['keep_prob'] = 1.0
return plh.get_feed_dict(inputs)
return feed_dict_creator
class eval_mgr():
def __init__(
self,
eval_data_json,
results_json,
inv_object_labels_dict,
inv_attribute_labels_dict,
):
self.eval_data_json= eval_data_json
self.results_json = results_json
self.inv_object_labels_dict = inv_object_labels_dict
self.inv_attribute_labels_dict = inv_attribute_labels_dict
self.eval_data = dict()
self.correct = 0
self.total = 0
self.results = []
self.seen_qids = set()
def eval(self, iter, eval_vars_dict, batch):
batch_size = len(batch['question_unencoded'])
k = 10
pred_obj_labels = self.get_top_k_labels(eval_vars_dict['object_prob'],k)
pred_atr_labels = self.get_top_k_labels(eval_vars_dict['attribute_prob'],k,'atr')
for j in xrange(batch_size):
dict_entry = dict()
dict_entry['question'] = batch['question_unencoded'][j]
dict_entry['positive_answer'] = {
batch['positive_answer_unencoded'][j]:
str(eval_vars_dict['answer_score_' + str(j)][0,0])}
dict_entry['negative_answers'] = dict()
for i in xrange(len(batch['negative_answers_unencoded'][j])):
answer = batch['negative_answers_unencoded'][j][i]
dict_entry['negative_answers'][answer] = \
str(eval_vars_dict['answer_score_' + str(j)][0,i+1])
dict_entry['relevance_scores'] = eval_vars_dict['relevance_prob_' + str(j)].tolist()
selected_region = np.argmax(dict_entry['relevance_scores'][0,:])
question_id = batch['question_id'][j]
pred_answer, pred_score = self.get_pred_answer(
[batch['positive_answer_unencoded'][j]] + \
batch['negative_answers_unencoded'][j],
eval_vars_dict['answer_score_' + str(j)][0,:].tolist()
)
result_entry = {
'question_id': int(question_id),
'answer': pred_answer,
'pred_obj_labels': pred_obj_labels[selected_region + j*constants.num_region_proposals]
'pred_atr_labels': pred_atr_labels[selected_region + j*constants.num_region_proposals]
}
if question_id not in self.seen_qids:
self.seen_qids.add(question_id)
self.results.append(result_entry)
else:
print 'Already evaluated on this sample'
self.eval_data[str(question_id)] = dict_entry
# print dict_entry
self.total += batch_size
self.correct += eval_vars_dict['accuracy']*batch_size
self.print_accuracy()
if iter%100==0:
self.write_data()
def get_pred_answer(self, answers, scores):
pred_answer = ''
pred_score = -1e5
for answer, score in zip(answers, scores):
if score > pred_score:
pred_score = score
pred_answer = answer
return pred_answer, pred_score
def get_top_k_labels(self, prob, k, type='obj'):
num_samples, num_classes = prob.shape
top_k_labels = [None]*num_samples
for i in xrange(num_samples):
top_k = np.argsort(prob[i,:]).tolist()[-1:-1-k:-1]
top_k_labels[i] = []
for idx in top_k:
if type=='obj':
top_k_labels[i] += [self.inv_object_labels_dict[idx]]
elif type=='atr':
top_k_labels[i] += [self.inv_attribute_labels_dict[idx]]
return top_k_labels
def is_correct(self, answer_scores):
max_id = np.argmax(answer_scores, 1)
if max_id[0]==0:
return True
def print_accuracy(self):
print 'Total: {} Correct: {} Accuracy: {}'.format(
self.total,
self.correct,
self.correct/float(self.total))
def write_data(self):
with open(self.eval_data_json, 'w') as file:
ujson.dump(self.eval_data, file, indent=4, sort_keys=True)
with open(self.results_json, 'w') as file:
ujson.dump(self.results, file, indent=4, sort_keys=True)
def eval(
batch_generator,
sess,
initializer,
vars_to_eval_dict,
feed_dict_creator,
evaluator):
vars_to_eval_names = []
vars_to_eval = []
for var_name, var in vars_to_eval_dict.items():
vars_to_eval_names += [var_name]
vars_to_eval += [var]
with sess.as_default():
initializer.initialize()
iter = 0
for batch in batch_generator:
print '---'
print 'Iter: {}'.format(iter)
feed_dict = feed_dict_creator(batch)
eval_vars = sess.run(
vars_to_eval,
feed_dict = feed_dict)
eval_vars_dict = {
var_name: eval_var for var_name, eval_var in
zip(vars_to_eval_names, eval_vars)}
evaluator.eval(iter, eval_vars_dict, batch)
iter+=1
evaluator.write_data()
if __name__=='__main__':
print 'Creating batch generator...'
batch_generator = create_batch_generator(constants.answer_eval_on)
print 'Creating computation graph...'
graph = train.graph_creator(
constants.tb_log_dir,
constants.answer_batch_size,
constants.image_size,
constants.num_negative_answers,
constants.answer_regularization_coeff,
constants.answer_batch_size*constants.num_region_proposals,
0,
0,
0,
constants.answer_obj_atr_loss_wt,
constants.answer_ans_loss_wt,
constants.answer_mil_loss_wt,
resnet_feat_dim=constants.resnet_feat_dim,
training=False)
print 'Starting a session...'
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.9
sess = tf.Session(config=config, graph=graph.tf_graph)
print 'Creating initializer...'
initializer = create_initializer(
graph,
sess,
constants.answer_model_to_eval)
print 'Creating feed dict creator...'
feed_dict_creator = create_feed_dict_creator(
graph.plh,
constants.num_negative_answers)
print 'Creating dict of vars to be evaluated...'
vars_to_eval_dict = {
'accuracy': graph.answer_accuracy,
'object_prob': graph.obj_atr_inference.object_prob,
'attribute_prob': graph.obj_atr_inference.attribute_prob,
}
for j in xrange(constants.answer_batch_size):
vars_to_eval_dict['answer_score_'+str(j)] = \
graph.answer_inference.answer_score[j]
vars_to_eval_dict['relevance_prob_'+str(j)] = \
graph.relevance_inference.answer_region_prob[j]
# Read inverse object labels
with open(constants.object_labels_json,'r') as file:
object_labels_dict = ujson.load(file)
inv_object_labels_dict = {int(v): k for k,v in object_labels_dict.items()}
# Read inverse attribute labels
with open(constants.attribute_labels_json,'r') as file:
attribute_labels_dict = ujson.load(file)
inv_attribute_labels_dict = {int(v): k for k,v in attribute_labels_dict.items()}
print 'Creating evaluation manager...'
evaluator = eval_mgr(
constants.answer_eval_data_json,
constants.answer_eval_results_json,
inv_object_labels_dict,
inv_attribute_labels_dict)
print 'Start training...'
eval(
batch_generator,
sess,
initializer,
vars_to_eval_dict,
feed_dict_creator,
evaluator)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment