From 47787cce58fcec966843f335a7d3c426cab5ff3d Mon Sep 17 00:00:00 2001
From: tgupta6 <tgupta6@illinois.edu>
Date: Sun, 10 Apr 2016 20:46:45 -0500
Subject: [PATCH] Add region relevance training and eval scripts

---
 .../answer_classifier/ans_data_io_helper.py   |   8 +
 .../answer_classifier/train_ans_classifier.py |   1 +
 .../region_ranker/eval_rel_classifier.py      |  72 +++++
 .../region_ranker/train_rel_classifier.py     | 303 ++++++++++++++++++
 classifiers/tf_graph_creation_helper.py       |  86 ++++-
 classifiers/train_classifiers.py              |  44 ++-
 6 files changed, 509 insertions(+), 5 deletions(-)
 create mode 100644 classifiers/region_ranker/eval_rel_classifier.py
 create mode 100644 classifiers/region_ranker/train_rel_classifier.py

diff --git a/classifiers/answer_classifier/ans_data_io_helper.py b/classifiers/answer_classifier/ans_data_io_helper.py
index c63f446..e6ab0c3 100644
--- a/classifiers/answer_classifier/ans_data_io_helper.py
+++ b/classifiers/answer_classifier/ans_data_io_helper.py
@@ -184,6 +184,14 @@ class batch_creator():
         return region_images, ans_labels, question_encodings, \
             region_score, partition
 
+    def reshape_score(self, region_score):
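+        """Reshape a flat score vector into a (batch, proposals) matrix.
+
+        region_score arrives with shape (1, batch_size*num_proposals);
+        the output has one row per image and one column per proposal.
+        """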
+        num_cols = num_proposals
+        num_rows = region_score.shape[1]//num_proposals
+        assertion_str = 'Number of proposals and batch size do not match ' + \
+                        'dimension of region_score'
+        assert (num_cols*num_rows==region_score.shape[1]), assertion_str
+
+        return np.reshape(region_score,[num_rows, num_cols],'C')
 
 class html_ans_table_writer():
     def __init__(self, filename):
diff --git a/classifiers/answer_classifier/train_ans_classifier.py b/classifiers/answer_classifier/train_ans_classifier.py
index 44992bb..5bd1022 100644
--- a/classifiers/answer_classifier/train_ans_classifier.py
+++ b/classifiers/answer_classifier/train_ans_classifier.py
@@ -192,6 +192,7 @@ def train(train_params):
     image_regions, questions, keep_prob, y, region_score= \
         graph_creator.placeholder_inputs_ans(len(vocab), len(ans_vocab), 
                                              mode='gt')
+
     y_pred_obj = graph_creator.obj_comp_graph(image_regions, 1.0)
     obj_feat_op = g.get_operation_by_name('obj/conv2/obj_feat')
     obj_feat = obj_feat_op.outputs[0]
diff --git a/classifiers/region_ranker/eval_rel_classifier.py b/classifiers/region_ranker/eval_rel_classifier.py
new file mode 100644
index 0000000..be4c701
--- /dev/null
+++ b/classifiers/region_ranker/eval_rel_classifier.py
@@ -0,0 +1,72 @@
+import sys
+import os
+import json
+import math
+import matplotlib.pyplot as plt
+import matplotlib.image as mpimg
+import numpy as np
+import pdb
+import tensorflow as tf
+import answer_classifier.ans_data_io_helper as ans_io_helper
+import region_ranker.perfect_ranker as region_proposer
+import region_ranker.train_rel_classifier as rel_trainer
+import tf_graph_creation_helper as graph_creator
+import plot_helper as plotter
+
+def eval(eval_params):
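+    """Restore a saved relevance model and report recall on the test set."""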
+    sess = tf.InteractiveSession()
+    train_anno_filename = eval_params['train_json']
+    test_anno_filename = eval_params['test_json']
+    regions_anno_filename = eval_params['regions_json']
+    image_regions_dir = eval_params['image_regions_dir']
+    outdir = eval_params['outdir']
+    model = eval_params['model']
+    batch_size = eval_params['batch_size']
+    test_start_id = eval_params['test_start_id']
+    test_set_size = eval_params['test_set_size']
+
+    if not os.path.exists(outdir):
+        os.mkdir(outdir)
+
+    qa_anno_dict_train = ans_io_helper.parse_qa_anno(train_anno_filename)
+    qa_anno_dict = ans_io_helper.parse_qa_anno(test_anno_filename)
+    region_anno_dict = region_proposer.parse_region_anno(regions_anno_filename)
+    ans_vocab, inv_ans_vocab = ans_io_helper.create_ans_dict()
+    vocab, inv_vocab = ans_io_helper.get_vocab(qa_anno_dict_train)
+
+    
+    # Create graph
+    g = tf.get_default_graph()
+    image_regions, questions, y, keep_prob = \
+        graph_creator.placeholder_inputs_rel(ans_io_helper.num_proposals,
+                                             len(vocab), mode='gt')
+    placeholders = [image_regions, questions, y, keep_prob]
+    y_pred = graph_creator.rel_comp_graph(image_regions, questions, 
+                                          keep_prob, len(vocab), batch_size)
+
+    # Restore model
+    restorer = tf.train.Saver()
+    assert (os.path.exists(model)), 'Failed to read model from file ' + model
+    restorer.restore(sess, model)
+
+    # Load mean image
+    mean_image = np.load('/home/tanmay/Code/GenVQA/Exp_Results/' + \
+                         'Obj_Classifier/mean_image.npy')
+
+    # Batch creator
+    test_batch_creator = ans_io_helper.batch_creator(test_start_id,
+                                                     test_start_id 
+                                                     + test_set_size - 1)
+
+    # Test Recall
+    test_recall = rel_trainer.evaluate(y_pred, qa_anno_dict, 
+                                       region_anno_dict, ans_vocab, 
+                                       vocab, image_regions_dir, 
+                                       mean_image, test_start_id, 
+                                       test_set_size, batch_size,
+                                       placeholders, 75, 75,
+                                       test_batch_creator, verbose=True)
+
+    print('Test Rec: ' + str(test_recall))
diff --git a/classifiers/region_ranker/train_rel_classifier.py b/classifiers/region_ranker/train_rel_classifier.py
new file mode 100644
index 0000000..d93db95
--- /dev/null
+++ b/classifiers/region_ranker/train_rel_classifier.py
@@ -0,0 +1,303 @@
+import sys
+import os
+import json
+import math
+import matplotlib.pyplot as plt
+import matplotlib.image as mpimg
+import numpy as np
+import pdb
+import tensorflow as tf
+import answer_classifier.ans_data_io_helper as ans_io_helper
+import region_ranker.perfect_ranker as region_proposer
+import tf_graph_creation_helper as graph_creator
+import plot_helper as plotter
+
+val_start_id = 89645
+val_set_size = 5000
+val_set_size_small = 500
+
+def recall(pred_scores, gt_scores, k):
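+    """Fraction of relevant regions that appear in the top-k predictions.
+
+    Example with hypothetical scores: pred_scores=[0.1, 0.7, 0.2] and
+    gt_scores=[0, 1, 1] rank region 1 first, so recall at k=1 recovers
+    1 of the 2 relevant regions and returns ~0.5. The epsilon in the
+    denominator guards against examples with no relevant region.
+    """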
+    inc_order = np.argsort(pred_scores, 0)
+    dec_order = inc_order[::-1]
+    gt_scores_ordered = gt_scores[dec_order]
+    rel_reg_recalled = np.sum(gt_scores_ordered[0:k]!=0)
+    rel_reg = np.sum(gt_scores!=0)
+    return rel_reg_recalled/(rel_reg+0.00001)
+
+
+def batch_recall(pred_scores, gt_scores, k):
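+    """Mean per-example recall@k over a batch.
+
+    k=-1 adapts k to each example's number of relevant regions, so a
+    perfect ranking scores 1.0.
+    """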
+    batch_size = pred_scores.shape[0]
+    total_recall = 0.0
+    for i in xrange(batch_size):
+        if k==-1:
+            k_ = np.sum(gt_scores[i,:]!=0)
+        else:
+            k_ = k
+        total_recall += recall(pred_scores[i,:], gt_scores[i,:], k_)
+
+    return total_recall/batch_size
+
+def evaluate(region_score_pred, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
+             image_dir, mean_image, start_index, val_set_size, batch_size,
+             placeholders, img_height, img_width, batch_creator, verbose=False):
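+    """Average recall@1 of predicted region scores over an evaluation set.
+
+    Runs floor(val_set_size/batch_size) mini-batches starting at
+    start_index, with dropout disabled (keep_prob fed as 1.0).
+    """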
+    
+    recall_at_k = 0
+    max_iter = int(math.floor(val_set_size/batch_size))
+    for i in xrange(max_iter):
+        if verbose==True:
+            print('Iter: ' + str(i+1) + '/' + str(max_iter))
+        region_images, ans_labels, questions, \
+        region_score_vec, partition = batch_creator \
+            .ans_mini_batch_loader(qa_anno_dict, region_anno_dict, 
+                                   ans_vocab, vocab, image_dir, mean_image, 
+                                   start_index+i*batch_size, batch_size, 
+                                   img_height, img_width, 3)
+        region_score = batch_creator.reshape_score(region_score_vec)
+
+        feed_dict = {
+            placeholders[0] : region_images, 
+            placeholders[1] : questions,
+            placeholders[2] : region_score,
+            placeholders[3] : 1.0,
+        }
+
+        region_score_pred_eval = region_score_pred.eval(feed_dict)
+    
+        recall_at_k += batch_recall(region_score_pred_eval, 
+                                    region_score, 1)
+        
+    recall_at_k /= max_iter
+
+    return recall_at_k
+
+
+def train(train_params):
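+    """Train the region relevance classifier.
+
+    Optionally fine-tunes from a saved checkpoint and logs train/val
+    recall after every epoch.
+    """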
+    sess = tf.InteractiveSession()
+    train_anno_filename = train_params['train_json']
+    test_anno_filename = train_params['test_json']
+    regions_anno_filename = train_params['regions_json']
+    image_dir = train_params['image_dir']
+    image_regions_dir = train_params['image_regions_dir']
+    outdir = train_params['outdir']
+    batch_size = train_params['batch_size']
+
+    if not os.path.exists(outdir):
+        os.mkdir(outdir)
+
+    qa_anno_dict = ans_io_helper.parse_qa_anno(train_anno_filename)
+    region_anno_dict = region_proposer.parse_region_anno(regions_anno_filename)
+    ans_vocab, inv_ans_vocab = ans_io_helper.create_ans_dict()
+    vocab, inv_vocab = ans_io_helper.get_vocab(qa_anno_dict)
+
+    # Save region crops
+    if train_params['crop_n_save_regions'] == True:
+        qa_anno_dict_test = ans_io_helper.parse_qa_anno(test_anno_filename)
+        ans_io_helper.save_regions(image_dir, image_regions_dir,
+                                   qa_anno_dict, region_anno_dict,
+                                   1, 94644, 75, 75)
+        ans_io_helper.save_regions(image_dir, image_regions_dir,
+                                   qa_anno_dict_test, region_anno_dict,
+                                   94645, 143495-94645+1, 75, 75)
+
+    
+    # Create graph
+    g = tf.get_default_graph()
+    image_regions, questions, y, keep_prob = \
+        graph_creator.placeholder_inputs_rel(ans_io_helper.num_proposals,
+                                             len(vocab), mode='gt')
+    placeholders = [image_regions, questions, y, keep_prob]
+    y_pred = graph_creator.rel_comp_graph(image_regions, questions, 
+                                          keep_prob, len(vocab), batch_size)
+
+    accuracy = graph_creator.evaluation(y, y_pred)
+    
+    cross_entropy = graph_creator.loss(y, y_pred)
+
+    # Collect variables
+    params_varnames = [
+        'rel/word_embed/word_vecs',
+        'rel/conv1/W',
+        'rel/conv2/W',
+        'rel/conv1/b',
+        'rel/conv2/b',
+        'rel/fc1/W_reg',
+        'rel/fc1/W_q',
+        'rel/fc1/b',
+        'rel/fc2/W',
+        'rel/fc2/b',
+    ]
+
+    vars_dict = graph_creator.get_list_of_variables(params_varnames)
+
+    # parameters grouped together
+    rel_word_params = [
+        vars_dict['rel/word_embed/word_vecs'],
+    ]
+    
+    rel_conv_params = [
+        vars_dict['rel/conv1/W'],
+        vars_dict['rel/conv2/W'],
+    ]
+ 
+    rel_fc_params = [
+        vars_dict['rel/fc1/W_reg'],
+        vars_dict['rel/fc1/W_q'],
+        vars_dict['rel/fc2/W'],
+    ]
+
+    vars_to_save = vars_dict.values()
+    vars_to_train = vars_to_save[:]
+    pretrained_vars = []
+
+    # Regularization
+    regularizer_rel_word_vecs = graph_creator.regularize_params(rel_word_params)
+    regularizer_rel_filters = graph_creator.regularize_params(rel_conv_params)
+    regularizer_rel_fcs = graph_creator.regularize_params(rel_fc_params)
+    
+    total_loss = cross_entropy + \
+                 1e-4 * regularizer_rel_word_vecs + \
+                 1e-3 * regularizer_rel_filters + \
+                 1e-4 * regularizer_rel_fcs
+
+    # Model to restore weights from
+    pretrained_model = os.path.join(outdir, 'rel_classifier-' + \
+                                    str(train_params['start_model']))
+    model_saver = tf.train.Saver(vars_to_save)
+    
+    if train_params['fine_tune']==True:
+        assert (os.path.exists(pretrained_model)), \
+            'Pretrained model does not exist'
+        model_saver.restore(sess, pretrained_model)
+        pretrained_vars = vars_to_save[:]
+        start_epoch = train_params['start_model'] + 1
+    else:
+        start_epoch = 0
+
+    # Attach optimization ops
+    train_step = tf.train.AdamOptimizer(train_params['adam_lr']) \
+                         .minimize(total_loss, var_list=vars_to_train)
+
+    # Initialize uninitialized vars
+    all_vars = tf.get_collection(tf.GraphKeys.VARIABLES)
+    vars_to_init = [var for var in all_vars if var not in pretrained_vars]
+    sess.run(tf.initialize_variables(vars_to_init))
+
+    # Load mean image
+    mean_image = np.load('/home/tanmay/Code/GenVQA/Exp_Results/' + \
+                         'Obj_Classifier/mean_image.npy')
+
+    # Start Training
+    max_epoch = train_params['max_epoch']
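+    # mini-batches per epoch; the train batch creator below spans
+    # question ids 1 through max_iter*batch_size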
+    max_iter = 4400*2
+    val_rec_array_epoch = np.zeros([max_epoch])
+    train_rec_array_epoch = np.zeros([max_epoch])
+
+    # Batch creators
+    train_batch_creator = ans_io_helper.batch_creator(1, max_iter*batch_size)
+    val_batch_creator = ans_io_helper.batch_creator(val_start_id, val_start_id 
+                                                    + val_set_size - 1)
+    val_small_batch_creator = ans_io_helper.batch_creator(val_start_id, 
+                                                          val_start_id + 
+                                                          val_set_size_small-1)
+
+    # Check accuracy of restored model
+    if train_params['fine_tune']==True:
+        restored_recall = evaluate(y_pred, qa_anno_dict, 
+                                   region_anno_dict, ans_vocab, 
+                                   vocab, image_regions_dir, 
+                                   mean_image, val_start_id, 
+                                   val_set_size, batch_size,
+                                   placeholders, 75, 75,
+                                   val_batch_creator)
+        print('Recall of restored model: ' + str(restored_recall))
+    
+    # Accuracy filename
+    train_recall_txtfile = os.path.join(outdir,'train_recall.txt')
+    val_recall_txtfile = os.path.join(outdir,'val_recall.txt')
+                 
+    for epoch in range(start_epoch, max_epoch):
+        train_batch_creator.shuffle_ids()
+        for i in range(max_iter):
+        
+            train_region_images, train_ans_labels, train_questions, \
+            train_region_score_vec, train_partition = train_batch_creator \
+                .ans_mini_batch_loader(qa_anno_dict, region_anno_dict, 
+                                       ans_vocab, vocab, 
+                                       image_regions_dir, mean_image, 
+                                       1+i*batch_size, batch_size, 
+                                       75, 75, 3)
+            train_region_score = train_batch_creator \
+                .reshape_score(train_region_score_vec)
+
+            feed_dict_train = {
+                image_regions : train_region_images, 
+                questions: train_questions,
+                keep_prob: 0.5,
+                y: train_region_score,
+            }
+            
+            _, current_train_batch_acc, y_pred_eval, loss_eval = \
+                    sess.run([train_step, accuracy, y_pred, total_loss], 
+                             feed_dict=feed_dict_train)
+
+            assert (not np.any(np.isnan(y_pred_eval))), 'NaN predicted'
+
+            train_rec_array_epoch[epoch] = train_rec_array_epoch[epoch] + \
+                                           batch_recall(y_pred_eval, 
+                                                        train_region_score, 5)
+        
+            if (i+1)%500==0:
+                val_recall = evaluate(y_pred, qa_anno_dict, 
+                                      region_anno_dict, ans_vocab, vocab,
+                                      image_regions_dir, mean_image, 
+                                      val_start_id, val_set_size_small,
+                                      batch_size, placeholders, 75, 75,
+                                      val_small_batch_creator)
+                
+                print('Iter: ' + str(i+1) + ' Val Sm Rec: ' + str(val_recall))
+
+        train_rec_array_epoch[epoch] = train_rec_array_epoch[epoch] / max_iter
+        val_rec_array_epoch[epoch] = evaluate(y_pred, qa_anno_dict, 
+                                              region_anno_dict, ans_vocab, 
+                                              vocab, image_regions_dir, 
+                                              mean_image, val_start_id, 
+                                              val_set_size, batch_size,
+                                              placeholders, 75, 75,
+                                              val_batch_creator)
+
+        print('Val Rec: ' + str(val_rec_array_epoch[epoch]) + 
+              ' Train Rec: ' + str(train_rec_array_epoch[epoch]))
+        
+        
+        if train_params['fine_tune']==True:
+            plot_path  = os.path.join(outdir, 'rec_vs_epoch_fine_tuned.pdf')
+        else:
+            plot_path = os.path.join(outdir, 'rec_vs_epoch.pdf')
+
+        plotter.write_accuracy_to_file(start_epoch, epoch, 
+                                       train_rec_array_epoch,
+                                       train_params['fine_tune'],
+                                       train_recall_txtfile)
+        plotter.write_accuracy_to_file(start_epoch, epoch, 
+                                       val_rec_array_epoch,
+                                       train_params['fine_tune'],
+                                       val_recall_txtfile)
+        plotter.plot_accuracies(xdata=np.arange(0, epoch + 1) + 1,
+                                ydata_train=train_rec_array_epoch[0:epoch + 1], 
+                                ydata_val=val_rec_array_epoch[0:epoch + 1], 
+                                xlim=[1, max_epoch], ylim=[0, 1.0], 
+                                savePath=plot_path)
+
+        save_path = model_saver.save(sess, 
+                                     os.path.join(outdir, 'rel_classifier'), 
+                                     global_step=epoch)
+
+    sess.close()
+    tf.reset_default_graph()
+
+
diff --git a/classifiers/tf_graph_creation_helper.py b/classifiers/tf_graph_creation_helper.py
index 6f02552..0e832d3 100644
--- a/classifiers/tf_graph_creation_helper.py
+++ b/classifiers/tf_graph_creation_helper.py
@@ -12,12 +12,16 @@ graph_config = {
     'region_feat_dim': 392, #3136
     'word_vec_dim': 50,
     'ans_fc1_dim': 300,
+    'rel_fc1_dim': 100,
 }
 
 def get_variable(var_scope):
     var_list = tf.get_collection(tf.GraphKeys.VARIABLES, scope=var_scope)
     
-    assert len(var_list)==1, 'Multiple variables exist by that name'
+    assert_str = 'No variable exists by that name: ' + var_scope
+    assert len(var_list)!=0, assert_str
+    assert_str = 'Multiple variables exist by that name: ' + var_scope
+    assert len(var_list)==1, assert_str
     
     return var_list[0]
 
@@ -48,7 +52,8 @@ def conv2d(x, W, var_name = 'conv'):
 
 
 def max_pool_2x2(x, var_name = 'h_pool'):
-    return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME', name=var_name)
+    return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], 
+                          padding='SAME', name=var_name)
 
 
 def placeholder_inputs(mode = 'gt'):
@@ -63,6 +68,20 @@ def placeholder_inputs(mode = 'gt'):
         return (x, keep_prob)
     
 
+def placeholder_inputs_rel(num_proposals, total_vocab_size, mode = 'gt'):
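+    """Create input placeholders for the relevance network.
+
+    In 'gt' mode a (batch, num_proposals) placeholder for ground truth
+    region scores is returned along with the inputs.
+    """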
+    image_regions = tf.placeholder(tf.float32, shape=[None,25,25,3])
+    keep_prob = tf.placeholder(tf.float32)
+    questions = tf.placeholder(tf.float32, shape=[None,total_vocab_size])
+    if mode == 'gt':
+        print 'Creating placeholder for ground truth'
+        y = tf.placeholder(tf.float32, shape=[None, num_proposals])
+        return (image_regions, questions, y, keep_prob)
+    if mode == 'no_gt':
+        print 'No placeholder for ground truth'
+        return (image_regions, questions, keep_prob)
+
+
 def placeholder_inputs_ans(total_vocab_size, ans_vocab_size, mode='gt'):
     image_regions = tf.placeholder(tf.float32, shape=[None,25,25,3])
     keep_prob = tf.placeholder(tf.float32)
@@ -162,6 +181,69 @@ def atr_comp_graph(x, keep_prob, obj_feat):
 
     return y_pred
 
+def rel_comp_graph(image_regions, questions, keep_prob, vocab_size, batch_size):
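+    """Build the region relevance scoring graph.
+
+    Each region crop passes through two conv+pool layers; region and
+    question features are fused in fc1, and fc2 emits one logit per
+    region. Logits are reshaped to (batch_size, num_proposals) and
+    softmaxed so each question's proposal scores sum to one.
+    """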
+
+    with tf.name_scope('rel') as rel_graph:
+
+        with tf.name_scope('word_embed') as q_embed:
+            word_vecs = weight_variable([vocab_size,
+                                         graph_config['word_vec_dim']],
+                                        var_name='word_vecs')
+            q_feat = tf.matmul(questions, word_vecs, name='q_feat')
+
+        with tf.name_scope('conv1') as conv1:
+            W_conv1 = weight_variable([5,5,3,4])
+            b_conv1 = bias_variable([4])
+            a_conv1 = tf.add(conv2d(image_regions, W_conv1), b_conv1, name='a')
+            h_conv1 = tf.nn.relu(a_conv1, name='h')
+            h_pool1 = max_pool_2x2(h_conv1)
+            h_conv1_drop = tf.nn.dropout(h_pool1, keep_prob, name='h_pool_drop')
+
+        with tf.name_scope('conv2') as conv2:
+            W_conv2 = weight_variable([3,3,4,8])
+            b_conv2 = bias_variable([8])
+            a_conv2 = tf.add(conv2d(h_conv1_drop, W_conv2), b_conv2, name='a')
+            h_conv2 = tf.nn.relu(a_conv2, name='h')
+            h_pool2 = max_pool_2x2(h_conv2)
+            h_pool2_drop = tf.nn.dropout(h_pool2, keep_prob, name='h_pool_drop')
+            h_pool2_drop_shape = h_pool2_drop.get_shape()
+            reg_feat_dim = reduce(lambda f, g: f*g, 
+                                  [dim.value for dim in h_pool2_drop_shape[1:]])
+            reg_feat = tf.reshape(h_pool2_drop, [-1, reg_feat_dim], 
+                                  name='reg_feat')
+
+            print('Relevance region feature dimension: ' + str(reg_feat_dim)) 
+
+        with tf.name_scope('fc1') as fc1:
+            fc1_dim = graph_config['rel_fc1_dim']
+            W_reg_fc1 = weight_variable([reg_feat_dim, fc1_dim], 
+                                        var_name='W_reg')
+            W_q_fc1 = weight_variable([graph_config['word_vec_dim'], 
+                                       fc1_dim], var_name='W_q')
+            b_fc1 = bias_variable([fc1_dim])
+            
+            a_reg_fc1 = tf.matmul(reg_feat, W_reg_fc1, name='a_reg_fc1')
+            a_q_fc1 = tf.matmul(q_feat, W_q_fc1, name='a_q_fc1')
+
+            a_fc1 = 0.5*a_reg_fc1 + 0.5*a_q_fc1 + b_fc1
+
+            h_fc1 = tf.nn.relu(a_fc1, name='h')
+            h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob, name='h_drop')
+
+        with tf.name_scope('fc2') as fc2:
+            W_fc2 = weight_variable([fc1_dim, 1])
+            b_fc2 = bias_variable([1])
+            
+            vec_logits = tf.add(tf.matmul(h_fc1_drop, W_fc2), b_fc2, 
+                                name='vec_logits')
+
+            logits = tf.reshape(vec_logits, [batch_size, -1])
+            
+        y_pred = tf.nn.softmax(logits, name='softmax')
+
+    return y_pred
+
 
 def ans_comp_graph(image_regions, questions, keep_prob, obj_feat, atr_feat, 
                    vocab, inv_vocab, ans_vocab_size, mode):
diff --git a/classifiers/train_classifiers.py b/classifiers/train_classifiers.py
index f9af94e..d4f89bb 100644
--- a/classifiers/train_classifiers.py
+++ b/classifiers/train_classifiers.py
@@ -10,17 +10,19 @@ import object_classifiers.eval_obj_classifier as obj_evaluator
 import attribute_classifiers.train_atr_classifier as atr_trainer
 import attribute_classifiers.eval_atr_classifier as atr_evaluator
 import answer_classifier.train_ans_classifier as ans_trainer
+import region_ranker.train_rel_classifier as rel_trainer
+import region_ranker.eval_rel_classifier as rel_evaluator
 
 workflow = {
     'train_obj': False,
     'eval_obj': False,
     'train_atr': False,
     'eval_atr': False,
-    'train_ans': True,
+    'train_ans': False,
+    'train_rel': False,
+    'eval_rel': True,
 }
 
-ans_mode = ['q']
-
 obj_classifier_train_params = {
     'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier',
     'adam_lr': 0.001,
@@ -76,6 +78,35 @@ ans_classifier_train_params = {
     'start_model': 4,
 }
 
+rel_classifier_train_params = {
+    'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
+    'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
+    'regions_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/regions_anno.json',
+    'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
+    'image_regions_dir': '/mnt/ramdisk/image_regions',
+    'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier',
+    'obj_atr_model': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/obj_atr_classifier-1',
+    'adam_lr' : 0.0001,
+    'crop_n_save_regions': False,
+    'max_epoch': 5,
+    'batch_size': 10,
+    'fine_tune': True,
+    'start_model': 1,
+}
+
+rel_classifier_eval_params = {
+    'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
+    'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
+    'regions_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/regions_anno.json',
+    'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
+    'image_regions_dir': '/mnt/ramdisk/image_regions',
+    'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier',
+    'model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier/rel_classifier-4',
+    'batch_size': 20,
+    'test_start_id': 94645,
+    'test_set_size': 143495-94645+1,
+}
+
 if __name__=='__main__':
     if workflow['train_obj']:
         obj_trainer.train(obj_classifier_train_params)
@@ -89,5 +120,12 @@ if __name__=='__main__':
     if workflow['eval_atr']:
         atr_evaluator.eval(atr_classifier_eval_params)
 
+    if workflow['train_rel']:
+        rel_trainer.train(rel_classifier_train_params)
+
+    if workflow['eval_rel']:
+        rel_evaluator.eval(rel_classifier_eval_params)
+
     if workflow['train_ans']:
         ans_trainer.train(ans_classifier_train_params)
+
-- 
GitLab