From bdb5e93dc77f61d79bc74360401fad5cfd5a2301 Mon Sep 17 00:00:00 2001
From: tgupta6 <tgupta6@illinois.edu>
Date: Tue, 12 Apr 2016 21:47:42 -0500
Subject: [PATCH] Answer prediction with region scoring; no obj or atr features
 used for reg scoring

---
 .../answer_classifier/eval_ans_classifier.py  | 17 +++++---
 .../answer_classifier/train_ans_classifier.py | 33 +++++++++++-----
 classifiers/train_classifiers.py              | 39 ++++++++++---------
 3 files changed, 55 insertions(+), 34 deletions(-)

diff --git a/classifiers/answer_classifier/eval_ans_classifier.py b/classifiers/answer_classifier/eval_ans_classifier.py
index 642fd05..b5f6ee6 100644
--- a/classifiers/answer_classifier/eval_ans_classifier.py
+++ b/classifiers/answer_classifier/eval_ans_classifier.py
@@ -113,6 +113,8 @@ def eval(eval_params):
     image_regions, questions, keep_prob, y, region_score= \
         graph_creator.placeholder_inputs_ans(len(vocab), len(ans_vocab), 
                                              mode='gt')
+    pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
+                                                  1.0, len(vocab), batch_size) 
     y_pred_obj = graph_creator.obj_comp_graph(image_regions, 1.0)
     obj_feat_op = g.get_operation_by_name('obj/conv2/obj_feat')
     obj_feat = obj_feat_op.outputs[0]
@@ -124,7 +126,10 @@ def eval(eval_params):
                                           obj_feat, atr_feat, vocab, 
                                           inv_vocab, len(ans_vocab), 
                                           eval_params['mode'])
-    y_avg = graph_creator.aggregate_y_pred(y_pred, region_score, batch_size,
+    pred_rel_score_vec = tf.reshape(pred_rel_score, 
+                                    [1, batch_size*ans_io_helper.num_proposals])
+    y_avg = graph_creator.aggregate_y_pred(y_pred, pred_rel_score_vec, 
+                                           batch_size,  
                                            ans_io_helper.num_proposals, 
                                            len(ans_vocab))
     
@@ -132,12 +137,13 @@ def eval(eval_params):
     accuracy = graph_creator.evaluation(y, y_avg)
 
     # Collect variables
+    rel_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='rel')
     obj_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='obj')
     atr_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr')
     pretrained_vars, vars_to_train, vars_to_restore, vars_to_save, \
         vars_to_init, vars_dict = ans_trainer \
             .get_process_flow_vars(eval_params['mode'], 
-                                   obj_vars, atr_vars,
+                                   obj_vars, atr_vars, rel_vars,
                                    True)
 
     # Restore model
@@ -257,7 +263,7 @@ def create_html_file(outdir, test_anno_filename, regions_anno_filename,
     
 
 if __name__=='__main__':
-    mode = 'q_reg'
+    mode = 'q_obj_atr_reg'
     model_num = 9
     ans_classifier_eval_params = {
         'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
@@ -265,8 +271,9 @@ if __name__=='__main__':
         'regions_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/regions_anno.json',
         'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
         'image_regions_dir': '/mnt/ramdisk/image_regions',
-        'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier',
-        'model': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier/ans_classifier_' + mode + '-' + str(model_num),
+        'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier_w_Rel',
+        'rel_model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier/rel_classifier-4',
+        'model': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier_w_Rel/ans_classifier_' + mode + '-' + str(model_num),
         'mode' : mode,
         'batch_size': 20,
         'test_start_id': 94645,
diff --git a/classifiers/answer_classifier/train_ans_classifier.py b/classifiers/answer_classifier/train_ans_classifier.py
index 5bd1022..372a2bf 100644
--- a/classifiers/answer_classifier/train_ans_classifier.py
+++ b/classifiers/answer_classifier/train_ans_classifier.py
@@ -20,7 +20,7 @@ val_start_id = 89645
 val_set_size = 5000
 val_set_size_small = 500
 
-def get_process_flow_vars(mode, obj_vars, atr_vars, fine_tune):
+def get_process_flow_vars(mode, obj_vars, atr_vars, rel_vars, fine_tune):
     list_of_vars = [
         'ans/word_embed/word_vecs',
         'ans/conv1/W',
@@ -94,14 +94,15 @@ def get_process_flow_vars(mode, obj_vars, atr_vars, fine_tune):
 
     # Fine tune begining with a previous model
     if fine_tune==True:
-        vars_to_restore = obj_vars + atr_vars + vars_to_train
+        vars_to_restore = obj_vars + atr_vars + rel_vars + vars_to_train
         if not mode=='q':
             vars_to_restore += [vars_dict['ans/word_embed/word_vecs']]
     else:
-        vars_to_restore = obj_vars + atr_vars + pretrained_vars
+        vars_to_restore = obj_vars + atr_vars + rel_vars + pretrained_vars
+#        vars_to_restore = pretrained_vars
 
     # Save trained vars
-    vars_to_save = obj_vars + atr_vars + vars_to_train + \
+    vars_to_save = obj_vars + atr_vars + rel_vars + vars_to_train + \
                    [vars_dict['ans/word_embed/word_vecs']]
 
     # Initialize vars_to_init
@@ -165,6 +166,7 @@ def train(train_params):
     image_dir = train_params['image_dir']
     image_regions_dir = train_params['image_regions_dir']
     outdir = train_params['outdir']
+    rel_model = train_params['rel_model']
     obj_atr_model = train_params['obj_atr_model']
     batch_size = train_params['batch_size']
 
@@ -192,7 +194,8 @@ def train(train_params):
     image_regions, questions, keep_prob, y, region_score= \
         graph_creator.placeholder_inputs_ans(len(vocab), len(ans_vocab), 
                                              mode='gt')
-
+    pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
+                                                  1.0, len(vocab), batch_size) 
     y_pred_obj = graph_creator.obj_comp_graph(image_regions, 1.0)
     obj_feat_op = g.get_operation_by_name('obj/conv2/obj_feat')
     obj_feat = obj_feat_op.outputs[0]
@@ -200,17 +203,25 @@ def train(train_params):
     atr_feat_op = g.get_operation_by_name('atr/conv2/atr_feat')
     atr_feat = atr_feat_op.outputs[0]
 
-    # Restore obj and attribute classifier parameters
+    # Restore rel, obj and attribute classifier parameters
+    rel_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='rel')
     obj_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='obj')
     atr_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr')
+
+    rel_saver = tf.train.Saver(rel_vars)
     obj_atr_saver = tf.train.Saver(obj_vars+atr_vars)
+
+    rel_saver.restore(sess, rel_model)
     obj_atr_saver.restore(sess, obj_atr_model)
 
     y_pred = graph_creator.ans_comp_graph(image_regions, questions, keep_prob, 
                                           obj_feat, atr_feat, vocab, 
                                           inv_vocab, len(ans_vocab), 
                                           train_params['mode'])
-    y_avg = graph_creator.aggregate_y_pred(y_pred, region_score, batch_size,
+    pred_rel_score_vec = tf.reshape(pred_rel_score, 
+                                    [1, batch_size*ans_io_helper.num_proposals])
+    y_avg = graph_creator.aggregate_y_pred(y_pred, 
+                                           pred_rel_score_vec, batch_size, 
                                            ans_io_helper.num_proposals, 
                                            len(ans_vocab))
     
@@ -222,7 +233,7 @@ def train(train_params):
     pretrained_vars, vars_to_train, vars_to_restore, vars_to_save, \
         vars_to_init, vars_dict = \
             get_process_flow_vars(train_params['mode'], 
-                                  obj_vars, atr_vars,
+                                  obj_vars, atr_vars, rel_vars,
                                   train_params['fine_tune'])
 
     # Regularizers
@@ -274,11 +285,13 @@ def train(train_params):
                                      train_params['mode'] + '-' + \
                                      str(train_params['start_model']))
         start_epoch = train_params['start_model']+1
+        partial_restorer = tf.train.Saver(vars_to_restore)
     else:
         start_epoch = 0
+        partial_restorer = tf.train.Saver(pretrained_vars)
 
     # Restore partial model
-    partial_restorer = tf.train.Saver(vars_to_restore)
+#    partial_restorer = tf.train.Saver(vars_to_restore)
     if os.path.exists(partial_model):
         partial_restorer.restore(sess, partial_model)
 
@@ -293,7 +306,7 @@ def train(train_params):
     # Initialize vars_to_init
     all_vars = tf.get_collection(tf.GraphKeys.VARIABLES)
     optimizer_vars = [var for var in all_vars if var not in \
-                      obj_vars + atr_vars + ans_vars]
+                      obj_vars + atr_vars + rel_vars + ans_vars]
     
     print('Optimizer Variables: ')
     print([var.name for var in optimizer_vars])
diff --git a/classifiers/train_classifiers.py b/classifiers/train_classifiers.py
index d4f89bb..b95ef4b 100644
--- a/classifiers/train_classifiers.py
+++ b/classifiers/train_classifiers.py
@@ -18,9 +18,9 @@ workflow = {
     'eval_obj': False,
     'train_atr': False,
     'eval_atr': False,
-    'train_ans': False,
     'train_rel': False,
-    'eval_rel': True,
+    'eval_rel': False,
+    'train_ans': True,
 }
 
 obj_classifier_train_params = {
@@ -61,23 +61,6 @@ atr_classifier_eval_params = {
     'html_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/html_dir',
 }
 
-ans_classifier_train_params = {
-    'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
-    'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
-    'regions_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/regions_anno.json',
-    'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
-    'image_regions_dir': '/mnt/ramdisk/image_regions',
-    'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier',
-    'obj_atr_model': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/obj_atr_classifier-1',
-    'adam_lr' : 0.001,
-    'mode' : 'q_reg',
-    'crop_n_save_regions': False,
-    'max_epoch': 10,
-    'batch_size': 10,
-    'fine_tune': False,
-    'start_model': 4,
-}
-
 rel_classifier_train_params = {
     'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
     'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
@@ -107,6 +90,24 @@ rel_classifier_eval_params = {
     'test_set_size': 143495-94645+1,
 }
 
+ans_classifier_train_params = {
+    'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
+    'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
+    'regions_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/regions_anno.json',
+    'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
+    'image_regions_dir': '/mnt/ramdisk/image_regions',
+    'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier_w_Rel',
+    'rel_model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier/rel_classifier-4',
+    'obj_atr_model': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/obj_atr_classifier-1',
+    'adam_lr' : 0.0001,
+    'mode' : 'q_obj_atr_reg',
+    'crop_n_save_regions': False,
+    'max_epoch': 10,
+    'batch_size': 10,
+    'fine_tune': True,
+    'start_model': 4,
+}
+
 if __name__=='__main__':
     if workflow['train_obj']:
         obj_trainer.train(obj_classifier_train_params)
-- 
GitLab