Skip to content
Snippets Groups Projects
Commit db4a0395 authored by tgupta6's avatar tgupta6
Browse files

model with elementwise multiplication and linear classifier working

parent b6506c90
No related branches found
No related tags found
No related merge requests found
......@@ -23,40 +23,34 @@ class AnswerInference():
self.num_answers = num_answers
self.keep_prob = keep_prob
with tf.variable_scope('answer_graph'):
self.question = [None]*self.batch_size
for j in xrange(self.batch_size):
self.question[j] = tf.reshape(
self.question_vert_concat[j],
[1, -1])
self.answers = self.answers_vert_concat
self.qa_proj = [None]*self.batch_size
self.obj_atr_qa_elementwise_prod = [None]*self.batch_size
for j in xrange(self.batch_size):
self.qa_proj[j] = self.project_question_answer(
self.question[j],
self.answers[j],
space_dim,
'question_answer_projection',
None if j==0 else True)
self.obj_atr_qa_elementwise_prod[j] = self.elementwise_product(
self.object_feat[j],
self.attribute_feat[j],
self.question_vert_concat[j],
self.answers_vert_concat[j])
self.obj_atr_proj = [None]*self.batch_size
for j in xrange(self.batch_size):
self.obj_atr_proj[j] = self.project_object_attribute(
self.object_feat[j],
self.attribute_feat[j],
space_dim,
'object_attribute_projection',
None if j==0 else True)
self.per_region_answer_scores = [None]*self.batch_size
for j in xrange(self.batch_size):
self.per_region_answer_scores[j] = tf.matmul(
self.obj_atr_proj[j],
tf.transpose(self.qa_proj[j]),
name='per_region_answer_scores')
if j==0:
reuse_vars = False
else:
reuse_vars = True
self.per_region_answer_scores[j] = layers.conv2d(
tf.expand_dims(
self.obj_atr_qa_elementwise_prod[j],
0),
1,
1,
'per_region_ans_score_conv',
func = None,
reuse_vars = reuse_vars)
self.per_region_answer_scores[j] = tf.squeeze(
self.per_region_answer_scores[j],
[0,3])
self.per_region_answer_prob = [None]*self.batch_size
self.answer_score = [None]*self.batch_size
......@@ -73,6 +67,25 @@ class AnswerInference():
answer_score_tmp,
0,
keep_dims=True)
def elementwise_product(self, obj_feat, atr_feat, ques_feat, ans_feat):
tiled_ques = tf.tile(tf.reshape(ques_feat,[1, -1]),[self.num_answers,1])
qa_feat = tf.concat(
1,
[tiled_ques, ans_feat])
qa_feat = tf.tile(qa_feat, [1,2])
obj_atr_feat = tf.concat(
1,
[obj_feat, atr_feat])
obj_atr_feat = tf.tile(obj_atr_feat, [1,5])
obj_atr_feat = tf.expand_dims(obj_atr_feat,1)
feat = obj_atr_feat*qa_feat
return feat
def concat_object_attribute(self, object_feat, attribute_feat):
object_attribute = tf.concat(
......
......@@ -555,7 +555,7 @@ class attach_optimizer():
self.optimizer.add_variables(
self.graph.object_attribute_vars + self.graph.word_vec_vars,
learning_rate = 0.0*self.lr)
learning_rate = 1.0*self.lr)
self.optimizer.add_variables(
......
......@@ -5,7 +5,7 @@ def mkdir_if_not_exists(dir_name):
if not os.path.exists(dir_name):
os.mkdir(dir_name)
experiment_name = 'object_attribute_classifier_wordvec_xform' #'QA_joint_pretrain_genome_split'
experiment_name = 'QA_classifier_wordvec_xform' #'QA_joint_pretrain_genome_split'
# Global output directory (all subexperiments will be saved here)
global_output_dir = '/home/tanmay/Code/GenVQA/Exp_Results/VQA'
......@@ -104,7 +104,7 @@ region_batch_size = 200
region_num_epochs = 6
region_queue_size = 400
region_regularization_coeff = 1e-5
region_lr = 1e-4
region_lr = 1e-3
region_log_every_n_iter = 500
region_output_dir = os.path.join(
global_experiment_dir,
......@@ -204,8 +204,8 @@ answer_output_dir = os.path.join(
mkdir_if_not_exists(answer_output_dir)
pretrained_model = '/home/tanmay/Code/GenVQA/Exp_Results/VQA/' + \
'object_attribute_classifier_large_images_vqa_split/' + \
'object_attribute_classifiers/model-80000'
'object_attribute_classifier_wordvec_xform/' + \
'object_attribute_classifiers/model-102000'
answer_model = os.path.join(
answer_output_dir,
......@@ -215,7 +215,7 @@ answer_model = os.path.join(
num_regions_with_labels = 100
# Answer fine tune params
answer_fine_tune_from_iter = 18500
answer_fine_tune_from_iter = 13000
answer_fine_tune_from = answer_model + '-' + str(answer_fine_tune_from_iter)
# Answer eval params
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment