Skip to content
Snippets Groups Projects
Commit 849c57e3 authored by tgupta6's avatar tgupta6
Browse files

bn in answer inference

parent 7f64b299
No related branches found
No related tags found
No related merge requests found
......@@ -95,9 +95,10 @@ class AnswerInference():
reuse_vars = reuse_vars)
self.per_region_answer_scores[j] = tf.nn.relu(
self.batch_norm(
layers.batch_norm(
self.per_region_answer_scores[j],
self.is_training))
tf.constant(self.is_training),
reuse_vars = reuse_vars))
self.per_region_answer_scores[j] = layers.conv2d(
self.per_region_answer_scores[j],
......
......@@ -551,6 +551,8 @@ def create_vqa_batch_generator():
constants.vqa_train_subset_qids,
constants.vocab_json,
constants.vqa_answer_vocab_json,
constants.object_labels_json,
constants.attribute_labels_json,
constants.image_size,
constants.num_region_proposals,
constants.num_negative_answers,
......@@ -850,6 +852,8 @@ if __name__=='__main__':
constants.num_object_labels,
constants.num_attribute_labels,
constants.answer_obj_atr_loss_wt,
constants.answer_ans_loss_wt,
constants.answer_mil_loss_wt,
resnet_feat_dim=constants.resnet_feat_dim,
training=True)
......
......@@ -5,7 +5,7 @@ def mkdir_if_not_exists(dir_name):
if not os.path.exists(dir_name):
os.mkdir(dir_name)
experiment_name = 'QA_classifier_joint_pretrain_wordvec_xform_select' #'QA_joint_pretrain_genome_split'
experiment_name = 'class_through_vqa_mil_ans' #'QA_joint_pretrain_genome_split'
# Global output directory (all subexperiments will be saved here)
global_output_dir = '/home/tanmay/Code/GenVQA/Exp_Results/VQA'
......
......@@ -107,9 +107,9 @@ class data():
batch['negative_answers_unencoded'] = negative_answers_unencoded
batch['positive_nouns'] = batch['question_nouns'] \
+ batch['positive_answer_nouns'])
+ batch['positive_answer_nouns']
batch['positive_adjectives'] = batch['question_adjectives'] \
+ batch['positive_answer_adjectives'])
+ batch['positive_answer_adjectives']
_, batch['positive_nouns_vec_enc'] = self.noun_to_obj_id(
batch['positive_nouns'],
......@@ -155,8 +155,11 @@ class data():
for i, id in enumerate(list_of_noun_ids):
if id in self.inv_vocab:
obj_ids[i] = int(self.obj_labels[self.inv_vocab[id]])
vec_enc[obj_ids[i]] = 1.0
if self.inv_vocab[id] in self.obj_labels:
obj_ids[i] = int(self.obj_labels[self.inv_vocab[id]])
vec_enc[0,obj_ids[i]] = 1.0
else:
obj_ids[i] = -1
else:
obj_ids[i] = -1
......@@ -171,8 +174,9 @@ class data():
for i, id in enumerate(list_of_adj_ids):
if id in self.inv_vocab:
atr_ids[i] = int(self.atr_labels[self.inv_vocab[id]])
vec_enc[atr_ids[i]] = 1.0
if self.inv_vocab[id] in self.atr_labels:
atr_ids[i] = int(self.atr_labels[self.inv_vocab[id]])
vec_enc[0,atr_ids[i]] = 1.0
else:
atr_ids[i] = -1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment