From fe74071740907757623f26fe93b0122166c02554 Mon Sep 17 00:00:00 2001
From: tgupta6 <tgupta6@illinois.edu>
Date: Thu, 28 Apr 2016 12:28:37 -0500
Subject: [PATCH] add placeholder for question bins and containment

---
 .../answer_classifier/ans_data_io_helper.py   | 124 ++++++++++++++++--
 .../answer_classifier/train_ans_classifier.py | 105 ++++++++-------
 classifiers/tf_graph_creation_helper.py       |  67 ++++++++--
 classifiers/train_classifiers.py              |  11 +-
 4 files changed, 231 insertions(+), 76 deletions(-)

diff --git a/classifiers/answer_classifier/ans_data_io_helper.py b/classifiers/answer_classifier/ans_data_io_helper.py
index e6ab0c3..50ace34 100644
--- a/classifiers/answer_classifier/ans_data_io_helper.py
+++ b/classifiers/answer_classifier/ans_data_io_helper.py
@@ -51,6 +51,17 @@ def parse_qa_anno(json_filename):
     return qa_dict
                                                  
 
+def read_parsed_questions(json_filename):
+    with open(json_filename, 'r') as json_file:
+        raw_data = json.load(json_file)
+
+    parsed_q_dict = dict()
+    for entry in raw_data:
+        parsed_q_dict[entry['question_id']] = entry['question_parse']
+
+    return parsed_q_dict
+
+
 def get_vocab(qa_dict):
     vocab = dict()
     count = 0;
@@ -124,7 +135,8 @@ class batch_creator():
 
     def ans_mini_batch_loader(self, qa_dict, region_anno_dict, ans_dict, vocab, 
                               image_dir, mean_image, start_index, batch_size, 
-                              img_height=100, img_width=100, channels = 3):
+                              parsed_q_dict, img_height=100, img_width=100, 
+                              channels = 3):
 
         q_ids = self.qa_index(start_index, batch_size)
 
@@ -141,8 +153,9 @@ class batch_creator():
                                         region_shape[1], channels])
         region_score = np.zeros(shape=[1,count])
         partition = np.zeros(shape=[count])
-        question_encodings = np.zeros(shape=[count, len(vocab)])
-            
+        parsed_q = dict()
+#        question_encodings = np.zeros(shape=[count, len(vocab)])
+
         for i in xrange(batch_size):
             q_id = q_ids[i]
             image_id = qa_dict[q_id].image_id
@@ -155,18 +168,19 @@ class batch_creator():
                                                    gt_regions_for_image,
                                                    False)
             
-            question_encoding_tmp = np.zeros(shape=[1, len(vocab)])
-            for word in question[0:-1].split():
-                if word.lower() not in vocab:
-                    word = 'unk'
-                question_encoding_tmp[0, vocab[word.lower()]] += 1 
+            # question_encoding_tmp = np.zeros(shape=[1, len(vocab)])
+            # for word in question[0:-1].split():
+            #     if word.lower() not in vocab:
+            #         word = 'unk'
+            #     question_encoding_tmp[0, vocab[word.lower()]] += 1 
                     
-            question_len = np.sum(question_encoding_tmp)
-            assert (not question_len==0)
-            question_encoding_tmp /= question_len
+            # question_len = np.sum(question_encoding_tmp)
+            # assert (not question_len==0)
+            # question_encoding_tmp /= question_len
         
             for j in xrange(num_proposals):
                 counter = j + i*num_proposals
+                parsed_q[counter] = parsed_q_dict[q_id]
                 proposal = regions[j]
                 resized_region = mpimg.imread(os.path.join(image_dir, 
                                              '{}_{}.png'.format(image_id,j)))
@@ -175,14 +189,18 @@ class batch_creator():
                 region_score[0,counter] = proposal.score
                 partition[counter] = i
                 
-                question_encodings[counter,:] = question_encoding_tmp
+#                question_encodings[counter,:] = question_encoding_tmp
                 
             score_start_id = i*num_proposals
             region_score[0, score_start_id:score_start_id+num_proposals] /=\
-                np.sum(region_score[0,score_start_id
+                    np.sum(region_score[0,score_start_id
                                         : score_start_id+num_proposals])
-        return region_images, ans_labels, question_encodings, \
+
+        return region_images, ans_labels, parsed_q, \
             region_score, partition
+        # return region_images, ans_labels, question_encodings, \
+        #     region_score, partition
+        
 
     def reshape_score(self, region_score):
         num_cols = num_proposals
@@ -193,6 +211,84 @@ class batch_creator():
         
         return np.reshape(region_score,[num_rows, num_cols],'C')
 
+ 
+obj_labels = {
+    0: 'blank',
+    1: 'square',
+    2: 'triangle',
+    3: 'circle',
+}   
+
+
+atr_labels = {
+    0: 'red', 
+    1: 'green',
+    2: 'blue',
+    3: 'blank',
+}
+
+
+class feed_dict_creator():
+    def __init__(self, region_images, ans_labels, parsed_q, 
+                 region_score, keep_prob, plholder_dict, vocab):
+        self.plholder_dict = plholder_dict
+        self.parsed_q = parsed_q
+        self.vocab = vocab
+        self.max_words = 5
+        self.feed_dict = {
+            plholder_dict['image_regions']: region_images,
+            plholder_dict['keep_prob']: keep_prob,
+            plholder_dict['gt_answer']: ans_labels,
+            plholder_dict['region_score']: region_score,
+        }
+        self.add_bin('bin0')
+        self.add_bin('bin1')
+        self.add_bin('bin2')
+        self.add_bin('bin3')
+        for i in xrange(4):
+            bin_name = 'bin' + str(i)
+            self.label_bin_containment(bin_name, obj_labels, 'obj')
+            self.label_bin_containment(bin_name, atr_labels, 'atr')
+
+    def add_bin(self, bin_name):
+        num_q = len(self.parsed_q)
+        shape_list = [num_q, len(self.vocab)]
+        indices_list = []
+        values_list = []
+        for q_num in xrange(num_q):
+            item = self.parsed_q[q_num]
+            word_list = item[bin_name]
+            num_words = len(word_list)
+            assert_str = 'number of bin words exceeded limit'
+            assert (num_words <= self.max_words), assert_str
+            for word_num, word in enumerate(word_list):
+                if word=='':
+                    word = 'unk'
+                indices_list.append((q_num, word_num))
+                values_list.append(self.vocab[word.lower()])
+
+        # convert to numpy arrays
+        shape = np.asarray(shape_list)
+        indices = np.asarray(indices_list)
+        values = np.asarray(values_list)
+        self.feed_dict[self.plholder_dict[bin_name + '_indices']] = indices
+        self.feed_dict[self.plholder_dict[bin_name + '_values']] = values
+        self.feed_dict[self.plholder_dict[bin_name + '_shape']] = shape
+
+    def label_bin_containment(self, bin_name, labels, label_type):
+        num_q = len(self.parsed_q)
+        num_labels = len(labels)
+        containment = np.zeros([num_q, num_labels], dtype='float32')
+        for q_num in xrange(num_q):
+            for i, label in labels.items():
+                if label in [pq.lower() for pq in self.parsed_q[q_num][bin_name]]:
+                    containment[q_num,i] = 1
+
+        plholder = self.plholder_dict[bin_name + '_' + \
+                                      label_type + '_' + 'cont']
+        self.feed_dict[plholder] = containment
+        
+
 class html_ans_table_writer():
     def __init__(self, filename):
         self.filename = filename
diff --git a/classifiers/answer_classifier/train_ans_classifier.py b/classifiers/answer_classifier/train_ans_classifier.py
index 6d4213e..d1dcc3b 100644
--- a/classifiers/answer_classifier/train_ans_classifier.py
+++ b/classifiers/answer_classifier/train_ans_classifier.py
@@ -133,25 +133,24 @@ def get_process_flow_vars(mode, obj_vars, atr_vars, rel_vars, fine_tune):
 
 def evaluate(accuracy, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
              image_dir, mean_image, start_index, val_set_size, batch_size,
-             placeholders, img_height, img_width, batch_creator):
+             plholder_dict, img_height, img_width, batch_creator, 
+             parsed_q_dict):
     
     correct = 0
     max_iter = int(math.floor(val_set_size/batch_size))
     for i in xrange(max_iter):
-        region_images, ans_labels, questions, \
+        region_images, ans_labels, parsed_q, \
         region_score, partition= batch_creator \
             .ans_mini_batch_loader(qa_anno_dict, region_anno_dict, 
                                    ans_vocab, vocab, image_dir, mean_image, 
                                    start_index+i*batch_size, batch_size, 
+                                   parsed_q_dict, 
                                    img_height, img_width, 3)
             
-        feed_dict = {
-            placeholders[0] : region_images, 
-            placeholders[1] : questions,
-            placeholders[2] : 1.0,
-            placeholders[3] : ans_labels,        
-            placeholders[4] : region_score,
-        }
+        feed_dict = ans_io_helper.\
+                    feed_dict_creator(region_images, ans_labels, parsed_q, 
+                                      region_score, 1.0, plholder_dict, 
+                                      vocab).feed_dict
 
         correct = correct + accuracy.eval(feed_dict)
 
@@ -163,6 +162,7 @@ def train(train_params):
     
     train_anno_filename = train_params['train_json']
     test_anno_filename = train_params['test_json']
+    parsed_q_filename = train_params['parsed_q_json']
     regions_anno_filename = train_params['regions_json']
     image_dir = train_params['image_dir']
     image_regions_dir = train_params['image_regions_dir']
@@ -175,6 +175,7 @@ def train(train_params):
         os.mkdir(outdir)
 
     qa_anno_dict = ans_io_helper.parse_qa_anno(train_anno_filename)
+    parsed_q_dict = ans_io_helper.read_parsed_questions(parsed_q_filename)
     region_anno_dict = region_proposer.parse_region_anno(regions_anno_filename)
     ans_vocab, inv_ans_vocab = ans_io_helper.create_ans_dict()
     vocab, inv_vocab = ans_io_helper.get_vocab(qa_anno_dict)
@@ -192,43 +193,51 @@ def train(train_params):
     
     # Create graph
     g = tf.get_default_graph()
-    image_regions, questions, keep_prob, y, region_score= \
-        graph_creator.placeholder_inputs_ans(len(vocab), len(ans_vocab), 
-                                             mode='gt')
+    plholder_dict = graph_creator.placeholder_inputs_ans(len(vocab), 
+                                                         len(ans_vocab), 
+                                                         mode='gt')
+    image_regions = plholder_dict['image_regions']
+    questions = plholder_dict['questions']
+    keep_prob = plholder_dict['keep_prob']
+    y = plholder_dict['gt_answer']
+    region_score = plholder_dict['region_score']
+
     y_pred_obj = graph_creator.obj_comp_graph(image_regions, 1.0)
     obj_feat_op = g.get_operation_by_name('obj/conv2/obj_feat')
     obj_feat = obj_feat_op.outputs[0]
     y_pred_atr = graph_creator.atr_comp_graph(image_regions, 1.0, obj_feat)
     atr_feat_op = g.get_operation_by_name('atr/conv2/atr_feat')
     atr_feat = atr_feat_op.outputs[0]
+
     # pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
-    #                                               y_pred_obj, y_pred_atr, 
+    #                                               obj_feat, atr_feat, 
     #                                               'q_obj_atr_reg', 1.0, 
-    #                                               len(vocab), batch_size) 
-    pred_rel_score = graph_creator.rel_comp_graph(image_regions, questions,
-                                                  obj_feat, atr_feat, 
-                                                  'q_obj_atr_reg', 1.0, 
-                                                  len(vocab), batch_size) 
+    #                                               len(vocab), batch_size)
+    
 
     # Restore rel, obj and attribute classifier parameters
-    rel_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='rel')
+#    rel_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='rel')
     obj_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='obj')
     atr_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr')
 
-    rel_saver = tf.train.Saver(rel_vars)
+ #   rel_saver = tf.train.Saver(rel_vars)
     obj_atr_saver = tf.train.Saver(obj_vars+atr_vars)
 
-    rel_saver.restore(sess, rel_model)
+  #  rel_saver.restore(sess, rel_model)
     obj_atr_saver.restore(sess, obj_atr_model)
 
-    y_pred = graph_creator.ans_comp_graph(image_regions, questions, keep_prob, 
+    y_pred = graph_creator.ans_comp_graph(plholder_dict, 
                                           obj_feat, atr_feat, vocab, 
                                           inv_vocab, len(ans_vocab), 
                                           train_params['mode'])
-    pred_rel_score_vec = tf.reshape(pred_rel_score, 
-                                    [1, batch_size*ans_io_helper.num_proposals])
+#    pred_rel_score_vec = tf.reshape(pred_rel_score, 
+    #                                 [1, batch_size*ans_io_helper.num_proposals])
+    # y_avg = graph_creator.aggregate_y_pred(y_pred, 
+    #                                        pred_rel_score_vec, batch_size, 
+    #                                        ans_io_helper.num_proposals, 
+    #                                        len(ans_vocab))
     y_avg = graph_creator.aggregate_y_pred(y_pred, 
-                                           pred_rel_score_vec, batch_size, 
+                                           region_score, batch_size, 
                                            ans_io_helper.num_proposals, 
                                            len(ans_vocab))
     
@@ -240,7 +249,7 @@ def train(train_params):
     pretrained_vars, vars_to_train, vars_to_restore, vars_to_save, \
         vars_to_init, vars_dict = \
             get_process_flow_vars(train_params['mode'], 
-                                  obj_vars, atr_vars, rel_vars,
+                                  obj_vars, atr_vars, [], #rel_vars,
                                   train_params['fine_tune'])
 
     # Regularizers
@@ -295,7 +304,8 @@ def train(train_params):
         partial_restorer = tf.train.Saver(vars_to_restore)
     else:
         start_epoch = 0
-        partial_restorer = tf.train.Saver(pretrained_vars)
+        if train_params['mode']!='q':
+            partial_restorer = tf.train.Saver(pretrained_vars)
 
     # Restore partial model
 #    partial_restorer = tf.train.Saver(vars_to_restore)
@@ -313,7 +323,7 @@ def train(train_params):
     # Initialize vars_to_init
     all_vars = tf.get_collection(tf.GraphKeys.VARIABLES)
     optimizer_vars = [var for var in all_vars if var not in \
-                      obj_vars + atr_vars + rel_vars + ans_vars]
+                      obj_vars + atr_vars + ans_vars] #rel_vars + ans_vars]
     
     print('Optimizer Variables: ')
     print([var.name for var in optimizer_vars])
@@ -347,8 +357,9 @@ def train(train_params):
                                      vocab, image_regions_dir, 
                                      mean_image, val_start_id, 
                                      val_set_size, batch_size,
-                                     placeholders, 75, 75,
-                                     val_batch_creator)
+                                     plholder_dict, 75, 75,
+                                     val_batch_creator,
+                                     parsed_q_dict)
         print('Accuracy of restored model: ' + str(restored_accuracy))
     
     # Accuracy filename
@@ -360,23 +371,25 @@ def train(train_params):
     for epoch in range(start_epoch, max_epoch):
         train_batch_creator.shuffle_ids()
         for i in range(max_iter):
-        
-            train_region_images, train_ans_labels, train_questions, \
+            train_region_images, train_ans_labels, train_parsed_q, \
             train_region_score, train_partition= train_batch_creator \
                 .ans_mini_batch_loader(qa_anno_dict, region_anno_dict, 
                                        ans_vocab, vocab, 
                                        image_regions_dir, mean_image, 
-                                       1+i*batch_size, batch_size, 
+                                       1+i*batch_size, batch_size,
+                                       parsed_q_dict,
                                        75, 75, 3)
-                        
-            feed_dict_train = {
-                image_regions : train_region_images, 
-                questions: train_questions,
-                keep_prob: 0.5,
-                y: train_ans_labels,        
-                region_score: train_region_score,
-            }
+
+            feed_dict_train = ans_io_helper \
+                .feed_dict_creator(train_region_images, 
+                                   train_ans_labels, 
+                                   train_parsed_q,
+                                   train_region_score,
+                                   0.5, 
+                                   plholder_dict,
+                                   vocab).feed_dict
             
+
             _, current_train_batch_acc, y_avg_eval, loss_eval = \
                     sess.run([train_step, accuracy, y_avg, total_loss], 
                              feed_dict=feed_dict_train)
@@ -394,8 +407,9 @@ def train(train_params):
                                         region_anno_dict, ans_vocab, vocab,
                                         image_regions_dir, mean_image, 
                                         val_start_id, val_set_size_small,
-                                        batch_size, placeholders, 75, 75,
-                                        val_small_batch_creator)
+                                        batch_size, plholder_dict, 75, 75,
+                                        val_small_batch_creator,
+                                        parsed_q_dict)
                 
                 print('Iter: ' + str(i+1) + ' Val Sm Acc: ' + str(val_accuracy))
 
@@ -405,8 +419,9 @@ def train(train_params):
                                               vocab, image_regions_dir, 
                                               mean_image, val_start_id, 
                                               val_set_size, batch_size,
-                                              placeholders, 75, 75,
-                                              val_batch_creator)
+                                              plholder_dict, 75, 75,
+                                              val_batch_creator,
+                                              parsed_q_dict)
 
         print('Val Acc: ' + str(val_acc_array_epoch[epoch]) + 
               ' Train Acc: ' + str(train_acc_array_epoch[epoch]))
diff --git a/classifiers/tf_graph_creation_helper.py b/classifiers/tf_graph_creation_helper.py
index 34cebf6..a60f2e2 100644
--- a/classifiers/tf_graph_creation_helper.py
+++ b/classifiers/tf_graph_creation_helper.py
@@ -11,6 +11,7 @@ graph_config = {
     'atr_feat_dim': 392,
     'region_feat_dim': 392, #3136
     'word_vec_dim': 50,
+    'q_embed_dim': 200,
     'ans_fc1_dim': 300,
     'rel_fc1_dim': 100,
 }
@@ -83,18 +84,40 @@ def placeholder_inputs_rel(num_proposals, total_vocab_size, mode = 'gt'):
 
 
 def placeholder_inputs_ans(total_vocab_size, ans_vocab_size, mode='gt'):
-    image_regions = tf.placeholder(tf.float32, shape=[None,25,25,3])
-    keep_prob = tf.placeholder(tf.float32)
-    questions = tf.placeholder(tf.float32, shape=[None,total_vocab_size])
-    region_score = tf.placeholder(tf.float32, shape=[1,None])
-    
+    plholder_dict = {
+        'image_regions': tf.placeholder(tf.float32, [None,25,25,3], 
+                                        'image_regions'),
+        'keep_prob': tf.placeholder(tf.float32, name='keep_prob'),
+        'questions': tf.placeholder(tf.float32, [None,total_vocab_size], 
+                                    'questions'),
+        'region_score': tf.placeholder(tf.float32, [1,None], 
+                                       'region_score'),
+    }
+    for i in xrange(4):
+        bin_name = 'bin' + str(i)
+        plholder_dict[bin_name + '_shape'] = \
+            tf.placeholder(tf.int64, [2], bin_name + '_shape')
+        plholder_dict[bin_name + '_indices'] = \
+            tf.placeholder(tf.int64, [None, 2], bin_name + '_indices')
+        plholder_dict[bin_name + '_values'] = \
+            tf.placeholder(tf.int64, [None], bin_name + '_values')
+        plholder_dict[bin_name + '_obj_cont'] = \
+            tf.placeholder(tf.float32, [None, graph_config['num_objects']],
+                           bin_name + '_obj_cont')
+        plholder_dict[bin_name + '_atr_cont'] = \
+            tf.placeholder(tf.float32, [None, graph_config['num_attributes']],
+                           bin_name + '_atr_cont')
+        
     if mode == 'gt':
         print 'Creating placeholder for ground truth'
-        gt_answer = tf.placeholder(tf.float32, shape=[None, ans_vocab_size])
-        return (image_regions, questions, keep_prob, gt_answer, region_score)
+        plholder_dict['gt_answer'] = tf.placeholder(tf.float32, 
+                                                    shape=[None, 
+                                                           ans_vocab_size],
+                                                    name = 'gt_answer')
+        return plholder_dict
     if mode == 'no_gt':
         print 'No placeholder for ground truth'
-        return (image_regions, questions, keep_prob, region_score)
+        return plholder_dict
         
 
 def obj_comp_graph(x, keep_prob):
@@ -285,9 +308,21 @@ def rel_comp_graph(image_regions, questions, obj_feat, atr_feat,
     return y_pred
 
 
-def ans_comp_graph(image_regions, questions, keep_prob, obj_feat, atr_feat, 
+def q_bin_embed_graph(bin_name, word_vecs, plholder_dict):
+    indices = plholder_dict[bin_name + '_indices']
+    values = plholder_dict[bin_name + '_values']
+    shape = plholder_dict[bin_name + '_shape']
+    sp_ids = tf.SparseTensor(indices, values, shape)
+    return tf.nn.embedding_lookup_sparse(word_vecs, sp_ids, None, 
+                                         name=bin_name + '_embedding')
+    
+
+def ans_comp_graph(plholder_dict, obj_feat, atr_feat, 
                    vocab, inv_vocab, ans_vocab_size, mode):
     vocab_size = len(vocab)
+    image_regions = plholder_dict['image_regions']
+    keep_prob = plholder_dict['keep_prob']
+
     with tf.name_scope('ans') as ans_graph:
 
         with tf.name_scope('word_embed') as word_embed:
@@ -295,8 +330,16 @@ def ans_comp_graph(image_regions, questions, keep_prob, obj_feat, atr_feat,
             word_vecs = weight_variable([vocab_size,
                                          graph_config['word_vec_dim']],
                                         var_name='word_vecs')
-            q_feat = tf.matmul(questions, word_vecs, name='q_feat')
-
+            
+            bin0_embed = q_bin_embed_graph('bin0', word_vecs, plholder_dict)
+            bin1_embed = q_bin_embed_graph('bin1', word_vecs, plholder_dict)
+            bin2_embed = q_bin_embed_graph('bin2', word_vecs, plholder_dict)
+            bin3_embed = q_bin_embed_graph('bin3', word_vecs, plholder_dict)
+            q_feat = tf.concat(1, [bin0_embed,
+                                   bin1_embed,
+                                   bin2_embed,
+                                   bin3_embed], name='q_feat')
+            
         with tf.name_scope('conv1') as conv1:
             num_filters_conv1 = 4
             W_conv1 = weight_variable([5,5,3,num_filters_conv1])
@@ -331,7 +374,7 @@ def ans_comp_graph(image_regions, questions, keep_prob, obj_feat, atr_feat,
                                          fc1_dim], var_name='W_obj')
             W_atr_fc1 = weight_variable([graph_config['atr_feat_dim'], 
                                          fc1_dim], var_name='W_atr')
-            W_q_fc1 = weight_variable([graph_config['word_vec_dim'], 
+            W_q_fc1 = weight_variable([graph_config['q_embed_dim'], 
                                        fc1_dim], var_name='W_q')
             b_fc1 = bias_variable([fc1_dim])
 
diff --git a/classifiers/train_classifiers.py b/classifiers/train_classifiers.py
index 0e9089e..015c6eb 100644
--- a/classifiers/train_classifiers.py
+++ b/classifiers/train_classifiers.py
@@ -97,18 +97,19 @@ ans_classifier_train_params = {
     'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
     'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
     'regions_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/regions_anno.json',
+    'parsed_q_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/parsed_questions.json',
     'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
     'image_regions_dir': '/mnt/ramdisk/image_regions',
     'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier_w_Rel',
     'rel_model': '/home/tanmay/Code/GenVQA/Exp_Results/Rel_Classifier/rel_classifier_q_obj_atr-4',
     'obj_atr_model': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/obj_atr_classifier-1',
-    'adam_lr' : 0.0001,
-    'mode' : 'q_obj_atr',
+    'adam_lr' : 0.001,
+    'mode' : 'q',
     'crop_n_save_regions': False,
-    'max_epoch': 10,
+    'max_epoch': 5,
     'batch_size': 10,
-    'fine_tune': True,
-    'start_model': 4,
+    'fine_tune': False,
+    'start_model': 4, # When fine_tune is false used to pre-initialize q_obj_atr with q model etc
 }
 
 if __name__=='__main__':
-- 
GitLab