Skip to content
Snippets Groups Projects
Commit f79fea89 authored by Kevin Shih's avatar Kevin Shih
Browse files
parents 15fbeaec 6db1d617
No related branches found
No related tags found
No related merge requests found
......@@ -113,15 +113,17 @@ def eval(eval_params):
image_regions, questions, keep_prob, y, region_score= \
graph_creator.placeholder_inputs_ans(len(vocab), len(ans_vocab),
mode='gt')
pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
1.0, len(vocab), batch_size)
y_pred_obj = graph_creator.obj_comp_graph(image_regions, 1.0)
obj_feat_op = g.get_operation_by_name('obj/conv2/obj_feat')
obj_feat = obj_feat_op.outputs[0]
y_pred_atr = graph_creator.atr_comp_graph(image_regions, 1.0, obj_feat)
atr_feat_op = g.get_operation_by_name('atr/conv2/atr_feat')
atr_feat = atr_feat_op.outputs[0]
pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
y_pred_obj, y_pred_atr,
'q_obj_atr_reg',
1.0, len(vocab), batch_size)
y_pred = graph_creator.ans_comp_graph(image_regions, questions, keep_prob,
obj_feat, atr_feat, vocab,
inv_vocab, len(ans_vocab),
......@@ -263,7 +265,7 @@ def create_html_file(outdir, test_anno_filename, regions_anno_filename,
if __name__=='__main__':
mode = 'q_obj_atr_reg'
mode = 'q_obj_atr'
model_num = 9
ans_classifier_eval_params = {
'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
......@@ -272,7 +274,7 @@ if __name__=='__main__':
'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
'image_regions_dir': '/mnt/ramdisk/image_regions',
'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier_w_Rel',
'rel_model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier/rel_classifier-4',
'rel_model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier_Obj_Atr_Prob/rel_classifier_q_obj_atr_reg-4',
'model': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier_w_Rel/ans_classifier_' + mode + '-' + str(model_num),
'mode' : mode,
'batch_size': 20,
......
......@@ -95,6 +95,7 @@ def get_process_flow_vars(mode, obj_vars, atr_vars, rel_vars, fine_tune):
# Fine tune begining with a previous model
if fine_tune==True:
vars_to_restore = obj_vars + atr_vars + rel_vars + vars_to_train
pretrained_vars = vars_to_restore[:]
if not mode=='q':
vars_to_restore += [vars_dict['ans/word_embed/word_vecs']]
else:
......@@ -200,9 +201,15 @@ def train(train_params):
y_pred_atr = graph_creator.atr_comp_graph(image_regions, 1.0, obj_feat)
atr_feat_op = g.get_operation_by_name('atr/conv2/atr_feat')
atr_feat = atr_feat_op.outputs[0]
# pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
# y_pred_obj, y_pred_atr,
# 'q_obj_atr_reg', 1.0,
# len(vocab), batch_size)
pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
obj_feat, atr_feat,
1.0, len(vocab), batch_size)
obj_feat, atr_feat,
'q_obj_atr_reg', 1.0,
len(vocab), batch_size)
# Restore rel, obj and attribute classifier parameters
rel_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='rel')
obj_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='obj')
......
......@@ -51,9 +51,13 @@ def eval(eval_params):
atr_feat_op = g.get_operation_by_name('atr/conv2/atr_feat')
atr_feat = atr_feat_op.outputs[0]
y_pred = graph_creator.rel_comp_graph(image_regions, questions,
obj_feat, atr_feat, mode,
y_pred_obj, y_pred_atr, mode,
keep_prob, len(vocab), batch_size)
# y_pred = graph_creator.rel_comp_graph(image_regions, questions,
# obj_feat, atr_feat, mode,
# keep_prob, len(vocab), batch_size)
# Restore model
restorer = tf.train.Saver()
if os.path.exists(model):
......
......@@ -18,9 +18,9 @@ workflow = {
'eval_obj': False,
'train_atr': False,
'eval_atr': False,
'train_rel': True,
'train_rel': False,
'eval_rel': False,
'train_ans': False,
'train_ans': True,
}
obj_classifier_train_params = {
......@@ -87,7 +87,7 @@ rel_classifier_eval_params = {
'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier_Obj_Atr_Prob',
'model_basedir': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier_Obj_Atr_Prob',
'model_number': 4,
'mode': 'q_obj_atr_reg',
'mode': 'q_obj_atr',
'batch_size': 20,
'test_start_id': 94645,
'test_set_size': 143495-94645+1,
......@@ -100,15 +100,15 @@ ans_classifier_train_params = {
'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
'image_regions_dir': '/mnt/ramdisk/image_regions',
'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier_w_Rel',
'rel_model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier/rel_classifier-4',
'rel_model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier/rel_classifier_q_obj_atr-4',
'obj_atr_model': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/obj_atr_classifier-1',
'adam_lr' : 0.001,
'adam_lr' : 0.0001,
'mode' : 'q_obj_atr',
'crop_n_save_regions': False,
'max_epoch': 5,
'max_epoch': 10,
'batch_size': 10,
'fine_tune': False,
'start_model': 2,
'fine_tune': True,
'start_model': 4,
}
if __name__=='__main__':
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment