diff --git a/classifiers/answer_classifier/eval_ans_classifier.py b/classifiers/answer_classifier/eval_ans_classifier.py
index b5f6ee65df4b97c4579d1ea2a53b7c6700049bcb..f0c66054ff6177982c9f7c8fbb377d77f9636c63 100644
--- a/classifiers/answer_classifier/eval_ans_classifier.py
+++ b/classifiers/answer_classifier/eval_ans_classifier.py
@@ -113,15 +113,17 @@ def eval(eval_params):
     image_regions, questions, keep_prob, y, region_score= \
         graph_creator.placeholder_inputs_ans(len(vocab), len(ans_vocab), mode='gt')
 
-    pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
-                                                  1.0, len(vocab), batch_size)
+
     y_pred_obj = graph_creator.obj_comp_graph(image_regions, 1.0)
     obj_feat_op = g.get_operation_by_name('obj/conv2/obj_feat')
     obj_feat = obj_feat_op.outputs[0]
     y_pred_atr = graph_creator.atr_comp_graph(image_regions, 1.0, obj_feat)
     atr_feat_op = g.get_operation_by_name('atr/conv2/atr_feat')
     atr_feat = atr_feat_op.outputs[0]
-
+    pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
+                                                  y_pred_obj, y_pred_atr,
+                                                  'q_obj_atr_reg',
+                                                  1.0, len(vocab), batch_size)
     y_pred = graph_creator.ans_comp_graph(image_regions, questions, keep_prob,
                                           obj_feat, atr_feat, vocab,
                                           inv_vocab, len(ans_vocab),
@@ -263,7 +265,7 @@ def create_html_file(outdir, test_anno_filename, regions_anno_filename,
 
 
 if __name__=='__main__':
-    mode = 'q_obj_atr_reg'
+    mode = 'q_obj_atr'
     model_num = 9
     ans_classifier_eval_params = {
         'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
@@ -272,7 +274,7 @@ if __name__=='__main__':
         'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
         'image_regions_dir': '/mnt/ramdisk/image_regions',
         'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier_w_Rel',
-        'rel_model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier/rel_classifier-4',
+        'rel_model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier_Obj_Atr_Prob/rel_classifier_q_obj_atr_reg-4',
         'model': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier_w_Rel/ans_classifier_' + mode + '-' + str(model_num),
         'mode' : mode,
         'batch_size': 20,
diff --git a/classifiers/answer_classifier/train_ans_classifier.py b/classifiers/answer_classifier/train_ans_classifier.py
index 54531430f7a2b8e4d5d17e1949fbaf18c8c6b523..6d4213eae4ab8f51de2daf44c26ff39c2292832b 100644
--- a/classifiers/answer_classifier/train_ans_classifier.py
+++ b/classifiers/answer_classifier/train_ans_classifier.py
@@ -95,6 +95,7 @@ def get_process_flow_vars(mode, obj_vars, atr_vars, rel_vars, fine_tune):
     # Fine tune begining with a previous model
     if fine_tune==True:
         vars_to_restore = obj_vars + atr_vars + rel_vars + vars_to_train
+        pretrained_vars = vars_to_restore[:]
         if not mode=='q':
             vars_to_restore += [vars_dict['ans/word_embed/word_vecs']]
     else:
@@ -200,9 +201,15 @@ def train(train_params):
     y_pred_atr = graph_creator.atr_comp_graph(image_regions, 1.0, obj_feat)
     atr_feat_op = g.get_operation_by_name('atr/conv2/atr_feat')
     atr_feat = atr_feat_op.outputs[0]
+    # pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
+    #                                               y_pred_obj, y_pred_atr,
+    #                                               'q_obj_atr_reg', 1.0,
+    #                                               len(vocab), batch_size)
     pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
-                                                 obj_feat, atr_feat,
-                                                 1.0, len(vocab), batch_size)
+                                                  obj_feat, atr_feat,
+                                                  'q_obj_atr_reg', 1.0,
+                                                  len(vocab), batch_size)
+
     # Restore rel, obj and attribute classifier parameters
     rel_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='rel')
     obj_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='obj')
diff --git a/classifiers/region_ranker/eval_rel_classifier.py b/classifiers/region_ranker/eval_rel_classifier.py
index 9e664c9f92f875b59629c1804ac288bc2e7c5fca..49046f51ad030decc8f96ccbce40c5e11a47c57c 100644
--- a/classifiers/region_ranker/eval_rel_classifier.py
+++ b/classifiers/region_ranker/eval_rel_classifier.py
@@ -51,9 +51,13 @@ def eval(eval_params):
     atr_feat_op = g.get_operation_by_name('atr/conv2/atr_feat')
     atr_feat = atr_feat_op.outputs[0]
     y_pred = graph_creator.rel_comp_graph(image_regions, questions,
-                                          obj_feat, atr_feat, mode,
+                                          y_pred_obj, y_pred_atr, mode,
                                           keep_prob, len(vocab), batch_size)
+    # y_pred = graph_creator.rel_comp_graph(image_regions, questions,
+    #                                       obj_feat, atr_feat, mode,
+    #                                       keep_prob, len(vocab), batch_size)
+
 
     # Restore model
     restorer = tf.train.Saver()
     if os.path.exists(model):
diff --git a/classifiers/train_classifiers.py b/classifiers/train_classifiers.py
index 9af0c7944cea7b3b84816db934b9978591bc4bf2..0e9089ee3b504003d40e7314e2b65e032a24a765 100644
--- a/classifiers/train_classifiers.py
+++ b/classifiers/train_classifiers.py
@@ -18,9 +18,9 @@ workflow = {
     'eval_obj': False,
     'train_atr': False,
     'eval_atr': False,
-    'train_rel': True,
+    'train_rel': False,
     'eval_rel': False,
-    'train_ans': False,
+    'train_ans': True,
 }
 
 obj_classifier_train_params = {
@@ -87,7 +87,7 @@ rel_classifier_eval_params = {
     'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier_Obj_Atr_Prob',
     'model_basedir': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier_Obj_Atr_Prob',
     'model_number': 4,
-    'mode': 'q_obj_atr_reg',
+    'mode': 'q_obj_atr',
     'batch_size': 20,
     'test_start_id': 94645,
     'test_set_size': 143495-94645+1,
@@ -100,15 +100,15 @@ ans_classifier_train_params = {
     'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
     'image_regions_dir': '/mnt/ramdisk/image_regions',
     'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier_w_Rel',
-    'rel_model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier/rel_classifier-4',
+    'rel_model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier/rel_classifier_q_obj_atr-4',
     'obj_atr_model': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/obj_atr_classifier-1',
-    'adam_lr' : 0.001,
+    'adam_lr' : 0.0001,
     'mode' : 'q_obj_atr',
     'crop_n_save_regions': False,
-    'max_epoch': 5,
+    'max_epoch': 10,
     'batch_size': 10,
-    'fine_tune': False,
-    'start_model': 2,
+    'fine_tune': True,
+    'start_model': 4,
 }
 
 if __name__=='__main__':