Skip to content
Snippets Groups Projects
Commit bdb5e93d authored by tgupta6's avatar tgupta6
Browse files

Answer prediction with region scoring; no object or attribute features are used for region scoring

parent 47787cce
No related branches found
No related tags found
No related merge requests found
......@@ -113,6 +113,8 @@ def eval(eval_params):
image_regions, questions, keep_prob, y, region_score= \
graph_creator.placeholder_inputs_ans(len(vocab), len(ans_vocab),
mode='gt')
pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
1.0, len(vocab), batch_size)
y_pred_obj = graph_creator.obj_comp_graph(image_regions, 1.0)
obj_feat_op = g.get_operation_by_name('obj/conv2/obj_feat')
obj_feat = obj_feat_op.outputs[0]
......@@ -124,7 +126,10 @@ def eval(eval_params):
obj_feat, atr_feat, vocab,
inv_vocab, len(ans_vocab),
eval_params['mode'])
y_avg = graph_creator.aggregate_y_pred(y_pred, region_score, batch_size,
pred_rel_score_vec = tf.reshape(pred_rel_score,
[1, batch_size*ans_io_helper.num_proposals])
y_avg = graph_creator.aggregate_y_pred(y_pred, pred_rel_score_vec,
batch_size,
ans_io_helper.num_proposals,
len(ans_vocab))
......@@ -132,12 +137,13 @@ def eval(eval_params):
accuracy = graph_creator.evaluation(y, y_avg)
# Collect variables
rel_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='rel')
obj_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='obj')
atr_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr')
pretrained_vars, vars_to_train, vars_to_restore, vars_to_save, \
vars_to_init, vars_dict = ans_trainer \
.get_process_flow_vars(eval_params['mode'],
obj_vars, atr_vars,
obj_vars, atr_vars, rel_vars,
True)
# Restore model
......@@ -257,7 +263,7 @@ def create_html_file(outdir, test_anno_filename, regions_anno_filename,
if __name__=='__main__':
mode = 'q_reg'
mode = 'q_obj_atr_reg'
model_num = 9
ans_classifier_eval_params = {
'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
......@@ -265,8 +271,9 @@ if __name__=='__main__':
'regions_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/regions_anno.json',
'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
'image_regions_dir': '/mnt/ramdisk/image_regions',
'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier',
'model': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier/ans_classifier_' + mode + '-' + str(model_num),
'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier_w_Rel',
'rel_model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier/rel_classifier-4',
'model': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier_w_Rel/ans_classifier_' + mode + '-' + str(model_num),
'mode' : mode,
'batch_size': 20,
'test_start_id': 94645,
......
......@@ -20,7 +20,7 @@ val_start_id = 89645
val_set_size = 5000
val_set_size_small = 500
def get_process_flow_vars(mode, obj_vars, atr_vars, fine_tune):
def get_process_flow_vars(mode, obj_vars, atr_vars, rel_vars, fine_tune):
list_of_vars = [
'ans/word_embed/word_vecs',
'ans/conv1/W',
......@@ -94,14 +94,15 @@ def get_process_flow_vars(mode, obj_vars, atr_vars, fine_tune):
# Fine tune begining with a previous model
if fine_tune==True:
vars_to_restore = obj_vars + atr_vars + vars_to_train
vars_to_restore = obj_vars + atr_vars + rel_vars + vars_to_train
if not mode=='q':
vars_to_restore += [vars_dict['ans/word_embed/word_vecs']]
else:
vars_to_restore = obj_vars + atr_vars + pretrained_vars
vars_to_restore = obj_vars + atr_vars + rel_vars + pretrained_vars
# vars_to_restore = pretrained_vars
# Save trained vars
vars_to_save = obj_vars + atr_vars + vars_to_train + \
vars_to_save = obj_vars + atr_vars + rel_vars + vars_to_train + \
[vars_dict['ans/word_embed/word_vecs']]
# Initialize vars_to_init
......@@ -165,6 +166,7 @@ def train(train_params):
image_dir = train_params['image_dir']
image_regions_dir = train_params['image_regions_dir']
outdir = train_params['outdir']
rel_model = train_params['rel_model']
obj_atr_model = train_params['obj_atr_model']
batch_size = train_params['batch_size']
......@@ -192,7 +194,8 @@ def train(train_params):
image_regions, questions, keep_prob, y, region_score= \
graph_creator.placeholder_inputs_ans(len(vocab), len(ans_vocab),
mode='gt')
pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
1.0, len(vocab), batch_size)
y_pred_obj = graph_creator.obj_comp_graph(image_regions, 1.0)
obj_feat_op = g.get_operation_by_name('obj/conv2/obj_feat')
obj_feat = obj_feat_op.outputs[0]
......@@ -200,17 +203,25 @@ def train(train_params):
atr_feat_op = g.get_operation_by_name('atr/conv2/atr_feat')
atr_feat = atr_feat_op.outputs[0]
# Restore obj and attribute classifier parameters
# Restore rel, obj and attribute classifier parameters
rel_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='rel')
obj_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='obj')
atr_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr')
rel_saver = tf.train.Saver(rel_vars)
obj_atr_saver = tf.train.Saver(obj_vars+atr_vars)
rel_saver.restore(sess, rel_model)
obj_atr_saver.restore(sess, obj_atr_model)
y_pred = graph_creator.ans_comp_graph(image_regions, questions, keep_prob,
obj_feat, atr_feat, vocab,
inv_vocab, len(ans_vocab),
train_params['mode'])
y_avg = graph_creator.aggregate_y_pred(y_pred, region_score, batch_size,
pred_rel_score_vec = tf.reshape(pred_rel_score,
[1, batch_size*ans_io_helper.num_proposals])
y_avg = graph_creator.aggregate_y_pred(y_pred,
pred_rel_score_vec, batch_size,
ans_io_helper.num_proposals,
len(ans_vocab))
......@@ -222,7 +233,7 @@ def train(train_params):
pretrained_vars, vars_to_train, vars_to_restore, vars_to_save, \
vars_to_init, vars_dict = \
get_process_flow_vars(train_params['mode'],
obj_vars, atr_vars,
obj_vars, atr_vars, rel_vars,
train_params['fine_tune'])
# Regularizers
......@@ -274,11 +285,13 @@ def train(train_params):
train_params['mode'] + '-' + \
str(train_params['start_model']))
start_epoch = train_params['start_model']+1
partial_restorer = tf.train.Saver(vars_to_restore)
else:
start_epoch = 0
partial_restorer = tf.train.Saver(pretrained_vars)
# Restore partial model
partial_restorer = tf.train.Saver(vars_to_restore)
# partial_restorer = tf.train.Saver(vars_to_restore)
if os.path.exists(partial_model):
partial_restorer.restore(sess, partial_model)
......@@ -293,7 +306,7 @@ def train(train_params):
# Initialize vars_to_init
all_vars = tf.get_collection(tf.GraphKeys.VARIABLES)
optimizer_vars = [var for var in all_vars if var not in \
obj_vars + atr_vars + ans_vars]
obj_vars + atr_vars + rel_vars + ans_vars]
print('Optimizer Variables: ')
print([var.name for var in optimizer_vars])
......
......@@ -18,9 +18,9 @@ workflow = {
'eval_obj': False,
'train_atr': False,
'eval_atr': False,
'train_ans': False,
'train_rel': False,
'eval_rel': True,
'eval_rel': False,
'train_ans': True,
}
obj_classifier_train_params = {
......@@ -61,23 +61,6 @@ atr_classifier_eval_params = {
'html_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/html_dir',
}
ans_classifier_train_params = {
'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
'regions_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/regions_anno.json',
'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
'image_regions_dir': '/mnt/ramdisk/image_regions',
'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier',
'obj_atr_model': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/obj_atr_classifier-1',
'adam_lr' : 0.001,
'mode' : 'q_reg',
'crop_n_save_regions': False,
'max_epoch': 10,
'batch_size': 10,
'fine_tune': False,
'start_model': 4,
}
rel_classifier_train_params = {
'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
......@@ -107,6 +90,24 @@ rel_classifier_eval_params = {
'test_set_size': 143495-94645+1,
}
ans_classifier_train_params = {
'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
'regions_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/regions_anno.json',
'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
'image_regions_dir': '/mnt/ramdisk/image_regions',
'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier_w_Rel',
'rel_model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier/rel_classifier-4',
'obj_atr_model': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/obj_atr_classifier-1',
'adam_lr' : 0.0001,
'mode' : 'q_obj_atr_reg',
'crop_n_save_regions': False,
'max_epoch': 10,
'batch_size': 10,
'fine_tune': True,
'start_model': 4,
}
if __name__=='__main__':
if workflow['train_obj']:
obj_trainer.train(obj_classifier_train_params)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment