Skip to content
Snippets Groups Projects
Commit adfc0ab6 authored by tgupta6's avatar tgupta6
Browse files

make ans eval and train compatible with each other

parent ab98b61f
No related branches found
No related tags found
No related merge requests found
...@@ -51,6 +51,8 @@ def create_batch_generator(mode): ...@@ -51,6 +51,8 @@ def create_batch_generator(mode):
qids_json, qids_json,
constants.vocab_json, constants.vocab_json,
constants.vqa_answer_vocab_json, constants.vqa_answer_vocab_json,
constants.object_labels_json,
constants.attribute_labels_json,
constants.image_size, constants.image_size,
constants.num_region_proposals, constants.num_region_proposals,
constants.num_negative_answers, constants.num_negative_answers,
...@@ -267,13 +269,15 @@ if __name__=='__main__': ...@@ -267,13 +269,15 @@ if __name__=='__main__':
0, 0,
0, 0,
constants.answer_obj_atr_loss_wt, constants.answer_obj_atr_loss_wt,
constants.answer_ans_loss_wt,
constants.answer_mil_loss_wt,
resnet_feat_dim=constants.resnet_feat_dim, resnet_feat_dim=constants.resnet_feat_dim,
training=False) training=False)
print 'Starting a session...' print 'Starting a session...'
config = tf.ConfigProto() config = tf.ConfigProto()
config.gpu_options.allow_growth = True config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.5 config.gpu_options.per_process_gpu_memory_fraction = 0.9
sess = tf.Session(config=config, graph=graph.tf_graph) sess = tf.Session(config=config, graph=graph.tf_graph)
print 'Creating initializer...' print 'Creating initializer...'
......
...@@ -15,7 +15,7 @@ import numpy as np ...@@ -15,7 +15,7 @@ import numpy as np
import pdb import pdb
from itertools import izip from itertools import izip
import tensorflow as tf import tensorflow as tf
import gc
class graph_creator(): class graph_creator():
def __init__( def __init__(
...@@ -484,23 +484,24 @@ class graph_creator(): ...@@ -484,23 +484,24 @@ class graph_creator():
["accuracy_answer"], ["accuracy_answer"],
self.moving_average_accuracy) self.moving_average_accuracy)
# object if self.training:
self.object_accuracy = self.add_object_accuracy_computation( # object
self.object_scores_with_labels, self.object_accuracy = self.add_object_accuracy_computation(
self.plh['object_labels']) self.object_scores_with_labels,
self.plh['object_labels'])
object_accuracy_summary = tf.scalar_summary(
"accuracy_object", object_accuracy_summary = tf.scalar_summary(
self.object_accuracy) "accuracy_object",
self.object_accuracy)
# attributes
self.attribute_accuracy = self.add_attribute_accuracy_computation( # attributes
self.attribute_scores_with_labels, self.attribute_accuracy = self.add_attribute_accuracy_computation(
self.plh['attribute_labels']) self.attribute_scores_with_labels,
self.plh['attribute_labels'])
attribute_accuracy_summary = tf.scalar_summary(
"accuracy_attribute", attribute_accuracy_summary = tf.scalar_summary(
self.attribute_accuracy) "accuracy_attribute",
self.attribute_accuracy)
def add_answer_accuracy_computation(self, scores): def add_answer_accuracy_computation(self, scores):
with tf.variable_scope('answer_accuracy'): with tf.variable_scope('answer_accuracy'):
...@@ -895,7 +896,10 @@ def train( ...@@ -895,7 +896,10 @@ def train(
zip(vars_to_eval_names, eval_vars)} zip(vars_to_eval_names, eval_vars)}
logger.log(iter, False, eval_vars_dict) logger.log(iter, False, eval_vars_dict)
iter+=1 iter+=1
if iter%8000==0 and iter!=0:
gc.collect()
logger.log(iter-1, True, eval_vars_dict) logger.log(iter-1, True, eval_vars_dict)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment