diff --git a/classifiers/answer_classifier/ans_data_io_helper.py b/classifiers/answer_classifier/ans_data_io_helper.py
index 58988558996a7751a51c370b4ba0cc3cc1c21592..5c55cde21a97e166b11d2a40851b184f336346e6 100644
--- a/classifiers/answer_classifier/ans_data_io_helper.py
+++ b/classifiers/answer_classifier/ans_data_io_helper.py
@@ -1,6 +1,7 @@
 import json
 import sys
 import os
+import time
 import matplotlib.pyplot as plt
 import matplotlib.image as mpimg
 import numpy as np
@@ -13,7 +14,7 @@ import region_ranker.perfect_ranker as region_proposer
 
 qa_tuple = namedtuple('qa_tuple','image_id question answer')
 num_proposals = 22
-region_coords = region_proposer.get_region_coords()
+region_coords, region_coords_ = region_proposer.get_region_coords(75, 75)
 
 
 def create_ans_dict():
@@ -65,7 +66,38 @@ def get_vocab(qa_dict):
     inv_vocab = {v: k for k, v in vocab.items()}
     return vocab, inv_vocab
 
+
+def save_regions(image_dir, out_dir, qa_dict, region_anno_dict, start_id,
+                 batch_size, img_width, img_height):
+
+    print('Saving Regions ...')
+    if not os.path.exists(out_dir):
+        os.mkdir(out_dir)
+
+    region_shape = np.array([img_height/3, img_width/3], np.int32)
+
+    image_done = np.zeros(shape=[batch_size], dtype=bool)
+    for i in xrange(start_id, start_id + batch_size):
+        image_id = qa_dict[i].image_id
+        if image_done[image_id-1]==False:
+            gt_regions_for_image = region_anno_dict[image_id]
+            image = mpimg.imread(os.path.join(image_dir,
+                                              str(image_id) + '.jpg'))
+            resized_image = misc.imresize(image, (img_height, img_width))
+            regions = region_proposer.rank_regions(resized_image, None,
+                                                   region_coords,
+                                                   region_coords_,
+                                                   gt_regions_for_image)
+            for j in xrange(num_proposals):
+                filename = os.path.join(out_dir, '{}_{}.png'.format(image_id,j))
+                resized_region = misc.imresize(regions[j].image,
+                                               (region_shape[0],
+                                                region_shape[1]))
+                misc.imsave(filename, resized_region)
+            image_done[image_id-1] = True
+
+
 def ans_mini_batch_loader(qa_dict, region_anno_dict, ans_dict, vocab,
                           image_dir, mean_image, start_index, batch_size,
                           img_height=100, img_width=100, channels = 3):
 
@@ -74,52 +107,58 @@ def ans_mini_batch_loader(qa_dict, region_anno_dict, ans_dict, vocab,
     for i in xrange(start_index, start_index + batch_size):
         answer = qa_dict[i].answer
         ans_labels[i-start_index, ans_dict[answer]] = 1
-
+
     # number of regions in the batch
     count = batch_size*num_proposals;
-
-    region_images = np.zeros(shape=[count, img_height,
-                                    img_width, channels])
+    region_shape = np.array([img_height/3, img_width/3], np.int32)
+    region_images = np.zeros(shape=[count, region_shape[0],
+                                    region_shape[1], channels])
     region_score = np.zeros(shape=[1,count])
     partition = np.zeros(shape=[count])
     question_encodings = np.zeros(shape=[count, len(vocab)])
 
     for i in xrange(start_index, start_index + batch_size):
+
         image_id = qa_dict[i].image_id
         question = qa_dict[i].question
         answer = qa_dict[i].answer
         gt_regions_for_image = region_anno_dict[image_id]
-        image = mpimg.imread(os.path.join(image_dir,
-                                          str(image_id) + '.jpg'))
-        regions = region_proposer.rank_regions(image, question, region_coords,
-                                               gt_regions_for_image)
+        start1 = time.time()
+        regions = region_proposer.rank_regions(None, question,
+                                               region_coords, region_coords_,
+                                               gt_regions_for_image,
+                                               False)
+
+        end1 = time.time()
+#        print('Ranking Region: ' + str(end1-start1))
+
+        # Encode the question once per sample; copied to each region row below.
+        question_encoding = np.zeros(shape=[1, len(vocab)])
+        for word in question[0:-1].split():
+            if word not in vocab:
+                word = 'unk'
+            question_encoding[0, vocab[word]] += 1
+
         for j in xrange(num_proposals):
             counter = j + (i-start_index)*num_proposals
-            resized_region = misc.imresize(regions[j].image, \
-                                           (img_height, img_width))
+            proposal = regions[j]
+
+            start2 = time.time()
+            resized_region = mpimg.imread(os.path.join(image_dir,
+                                                       '{}_{}.png'
+                                                       .format(image_id,j)))
+            end2 = time.time()
+#            print('Reading Region: ' + str(end2-start2))
             region_images[counter,:,:,:] = (resized_region / 254.0) \
                                            - mean_image
-            region_score[0,counter] = regions[j].score
+            region_score[0,counter] = proposal.score
             partition[counter] = i-start_index
+
+            question_encodings[counter,:] = question_encoding[0,:]
 
-        for word in question[0:-1].split():
-            if word not in vocab:
-                word = 'unk'
-            question_encodings[counter, vocab[word]] += 1
-
-    # Check for nans, infs
-    assert (not np.any(np.isnan(region_images))), "NaN in region_images"
-    assert (not np.any(np.isnan(ans_labels))), "NaN in labels"
-    assert (not np.any(np.isnan(question_encodings))), "NaN in question_encodings"
-    assert (not np.any(np.isnan(region_score))), "NaN in region_score"
-    assert (not np.any(np.isnan(partition))), "NaN in partition"
-
-    assert (not np.any(np.isinf(region_images))), "Inf in region_images"
-    assert (not np.any(np.isinf(ans_labels))), "Inf in labels"
-    assert (not np.any(np.isinf(question_encodings))), "Inf in question_encodings"
-    assert (not np.any(np.isinf(region_score))), "Inf in region_score"
-    assert (not np.any(np.isinf(partition))), "Inf in partition"
-
+        score_start_id = (i-start_index)*num_proposals
+        region_score[0, score_start_id:score_start_id+num_proposals] /= \
+            np.sum(region_score[0,score_start_id:score_start_id+num_proposals])
+
     return region_images, ans_labels, question_encodings, region_score, partition
diff --git a/classifiers/answer_classifier/ans_data_io_helper.pyc b/classifiers/answer_classifier/ans_data_io_helper.pyc
index 6eda2294ef5263df1f1e80c298cf9d93111a0b99..9ed62f9ffe32ff02671645bc90a02302313167b1 100644
Binary files a/classifiers/answer_classifier/ans_data_io_helper.pyc and b/classifiers/answer_classifier/ans_data_io_helper.pyc differ
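
The `ans_data_io_helper.py` hunks above swap per-minibatch cropping and resizing for a one-time pass (`save_regions`) that writes all 22 proposal crops per image to disk; `ans_mini_batch_loader` then only reads the cached `{image_id}_{j}.png` files back (the training script points the cache at `/mnt/ramdisk`, so the reads stay in memory). A minimal sketch of this crop-once, read-many pattern, with hypothetical names and `np.save`/`np.load` standing in for the PNG round-trip:

```python
import os
import numpy as np

def cache_region_crops(images, region_coords, cache_dir):
    # One-time pass: crop every proposal from every image and store it,
    # so the batch loader never touches the full image again.
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    for image_id, image in images.items():
        for j, (x1, y1, x2, y2) in enumerate(region_coords):
            # 1-based inclusive box coordinates, as in rank_regions above
            crop = image[y1 - 1:y2, x1 - 1:x2, :]
            np.save(os.path.join(cache_dir, '{}_{}.npy'.format(image_id, j)), crop)

def load_region_crop(cache_dir, image_id, j):
    # Batch-loader side: a cheap cache read replaces crop + resize.
    return np.load(os.path.join(cache_dir, '{}_{}.npy'.format(image_id, j)))

# Toy usage: one 6x6 RGB "image" and two proposals.
images = {1: np.arange(6 * 6 * 3, dtype=np.float32).reshape(6, 6, 3)}
coords = [(1, 1, 3, 3), (4, 4, 6, 6)]
cache_region_crops(images, coords, '/tmp/region_cache')
print(load_region_crop('/tmp/region_cache', 1, 0).shape)  # (3, 3, 3)
```
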
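The loader also now normalizes each sample's block of `num_proposals` scores to sum to one before returning, so that `aggregate_y_pred` computes a weighted average over proposals rather than an unnormalized sum. A toy rendering of that per-sample slice normalization (4 proposals, 2 samples, made-up scores; the patch implicitly assumes the ranker never returns an all-zero block, otherwise this divides by zero):

```python
import numpy as np

num_proposals = 4
region_score = np.array([[0., 1., 1., 0., 0., 0., 2., 2.]])  # 2 samples
for s in range(2):
    lo = s * num_proposals
    block = region_score[0, lo:lo + num_proposals]
    region_score[0, lo:lo + num_proposals] = block / np.sum(block)
print(region_score)  # [[0.  0.5 0.5 0.  0.  0.  0.5 0.5]]
```
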
diff --git a/classifiers/answer_classifier/train_ans_classifier.py b/classifiers/answer_classifier/train_ans_classifier.py
index 6d27300500b1ab60033e7fd188fd3141331e217b..1bf795b27b041730981e253ac09d8614aa21cf8b 100644
--- a/classifiers/answer_classifier/train_ans_classifier.py
+++ b/classifiers/answer_classifier/train_ans_classifier.py
@@ -4,6 +4,7 @@ import json
 import matplotlib.pyplot as plt
 import matplotlib.image as mpimg
 import numpy as np
+import math
 import tensorflow as tf
 import object_classifiers.obj_data_io_helper as obj_data_loader
 import attribute_classifiers.atr_data_io_helper as atr_data_loader
@@ -11,25 +12,28 @@ import tf_graph_creation_helper as graph_creator
 import plot_helper as plotter
 import ans_data_io_helper as ans_io_helper
 import region_ranker.perfect_ranker as region_proposer
-
+import time
 
 val_start_id = 106115
-val_batch_size = 1000
-
-batch_size = 10
-
+val_batch_size = 5000
+val_batch_size_small = 100
+batch_size = 20
+crop_n_save_regions = False
+restore_intermediate_model = True
 
 def evaluate(accuracy, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
-             image_dir, mean_image, start_index, batch_size,
+             image_dir, mean_image, start_index, val_batch_size,
              placeholders, img_height=100, img_width=100):
 
     correct = 0
-    for i in xrange(start_index, start_index + batch_size):
+    max_iter = int(math.floor(val_batch_size/batch_size))
+    for i in xrange(max_iter):
         region_images, ans_labels, questions, \
         region_score, partition= \
             ans_io_helper.ans_mini_batch_loader(qa_anno_dict, region_anno_dict,
                                                 ans_vocab, vocab,
                                                 image_dir, mean_image,
-                                                i, 1,
+                                                start_index+i*batch_size,
+                                                batch_size,
                                                 img_height, img_width, 3)
 
         feed_dict = {
@@ -42,7 +46,7 @@ def evaluate(accuracy, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
 
         correct = correct + accuracy.eval(feed_dict)
 
-    return correct/batch_size
+    return correct/max_iter
 
 
 def train(train_params):
@@ -51,12 +55,20 @@ def train(train_params):
 
     train_anno_filename = '/home/tanmay/Code/GenVQA/GenVQA/' + \
                           'shapes_dataset/train_anno.json'
+    test_anno_filename = '/home/tanmay/Code/GenVQA/GenVQA/' + \
+                         'shapes_dataset/test_anno.json'
+
     regions_anno_filename = '/home/tanmay/Code/GenVQA/GenVQA/' + \
                             'shapes_dataset/regions_anno.json'
     image_dir = '/home/tanmay/Code/GenVQA/GenVQA/' + \
                 'shapes_dataset/images'
+
+    # image_regions_dir = '/home/tanmay/Code/GenVQA/Exp_Results/' + \
+    #                     'image_regions'
+
+    image_regions_dir = '/mnt/ramdisk/image_regions'
+
     outdir = '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier'
     if not os.path.exists(outdir):
         os.mkdir(outdir)
@@ -66,28 +78,39 @@ def train(train_params):
     ans_vocab, inv_ans_vocab = ans_io_helper.create_ans_dict()
     vocab, inv_vocab = ans_io_helper.get_vocab(qa_anno_dict)
 
+    # Save region crops
+    if crop_n_save_regions == True:
+        qa_anno_dict_test = ans_io_helper.parse_qa_anno(test_anno_filename)
+        ans_io_helper.save_regions(image_dir, image_regions_dir,
+                                   qa_anno_dict, region_anno_dict,
+                                   1, 111351, 75, 75)
+        ans_io_helper.save_regions(image_dir, image_regions_dir,
+                                   qa_anno_dict_test, region_anno_dict,
+                                   111352, 160725-111352+1, 75, 75)
+
+
     # Create graph
     image_regions, questions, keep_prob, y, region_score= \
         graph_creator.placeholder_inputs_ans(len(vocab), len(ans_vocab),
                                              mode='gt')
-    y_pred_obj = graph_creator.obj_comp_graph(image_regions, keep_prob)
+    y_pred_obj = graph_creator.obj_comp_graph(image_regions, 1.0)
     obj_feat = tf.get_collection('obj_feat', scope='obj/conv2')
-    y_pred_atr = graph_creator.atr_comp_graph(image_regions, keep_prob, obj_feat[0])
+    y_pred_atr = graph_creator.atr_comp_graph(image_regions, 1.0, obj_feat[0])
     atr_feat = tf.get_collection('atr_feat', scope='atr/conv2')
 
     # model restoration
     obj_atr_saver = tf.train.Saver()
     model_to_restore = '/home/tanmay/Code/GenVQA/GenVQA/classifiers/' + \
                        'saved_models/obj_atr_classifier-1'
-    obj_atr_saver.restore(sess, model_to_restore)
+#    obj_atr_saver.restore(sess, model_to_restore)
 
-    y_pred = graph_creator.ans_comp_graph(image_regions, questions, keep_prob, \
-                                          obj_feat[0], atr_feat[0],
-                                          vocab, inv_vocab, len(ans_vocab))
+    y_pred, logits = graph_creator.ans_comp_graph(image_regions,
+                                                  questions, keep_prob, \
+                                                  obj_feat[0], atr_feat[0],
+                                                  vocab, inv_vocab, len(ans_vocab))
     y_avg = graph_creator.aggregate_y_pred(y_pred, region_score,
                                            batch_size,
                                            ans_io_helper.num_proposals,
                                            len(ans_vocab))
-#    y_avg = tf.matmul(region_score,y_pred)
 
     cross_entropy = graph_creator.loss(y, y_avg)
     accuracy = graph_creator.evaluation(y, y_avg)
@@ -97,21 +120,34 @@ def train(train_params):
     train_step = tf.train.AdamOptimizer(train_params['adam_lr']) \
                          .minimize(cross_entropy, var_list=vars_to_opt)
 
-    vars_to_restore = []
-    vars_to_restore.append(tf.get_collection(tf.GraphKeys.VARIABLES,
-                                             scope='obj'))
-    vars_to_restore.append(tf.get_collection(tf.GraphKeys.VARIABLES,
-                                             scope='atr'))
+    print(train_step.name)
+    vars_to_restore = tf.get_collection(tf.GraphKeys.VARIABLES, scope='obj') + \
+                      tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr') + \
+                      tf.get_collection(tf.GraphKeys.VARIABLES,
+                                        scope='ans/word_embed')
     all_vars = tf.get_collection(tf.GraphKeys.VARIABLES)
-    vars_to_init = [var for var in all_vars if var not in vars_to_restore]
+    vars_to_init = [var for var in all_vars if var not in vars_to_restore]
+
     # Session saver
-    saver = tf.train.Saver()
-
-    # Initializing all variables except those restored
-    print('Initializing variables')
-    sess.run(tf.initialize_variables(vars_to_init))
+    saver = tf.train.Saver(vars_to_restore)
+
+    if restore_intermediate_model==True:
+        intermediate_model = '/home/tanmay/Code/GenVQA/Exp_Results/' + \
+                             'Ans_Classifier/ans_classifier_question_only-9'
+        print('vars_to_restore: ')
+        print([var.name for var in vars_to_restore])
+        print('vars_to_init: ')
+        print([var.name for var in vars_to_init])
+        saver.restore(sess, intermediate_model)
+#        print('Initializing variables')
+        sess.run(tf.initialize_variables(vars_to_init))
+        start_epoch = 0
+    else:
+        # Initializing all variables except those restored
+        print('Initializing variables')
+        sess.run(tf.initialize_variables(vars_to_init))
+        start_epoch = 0
 
     # Load mean image
     mean_image = np.load('/home/tanmay/Code/GenVQA/Exp_Results/' + \
@@ -122,29 +158,21 @@ def train(train_params):
     # Start Training
     # batch_size = 1
     max_epoch = 10
-    max_iter = 9500
+    max_iter = 5000
     val_acc_array_epoch = np.zeros([max_epoch])
     train_acc_array_epoch = np.zeros([max_epoch])
-    for epoch in range(max_epoch):
+    for epoch in range(start_epoch, max_epoch):
+        start = time.time()
         for i in range(max_iter):
-            if i%100==0:
-                print('Iter: ' + str(i))
-
-            # val_accuracy = evaluate(accuracy, qa_anno_dict,
-            #                         region_anno_dict, ans_vocab, vocab,
-            #                         image_dir, mean_image,
-            #                         val_start_id, val_batch_size,
-            #                         placeholders, 25, 25)
-            # print(val_accuracy)
-
             train_region_images, train_ans_labels, train_questions, \
             train_region_score, train_partition= \
                 ans_io_helper.ans_mini_batch_loader(qa_anno_dict,
                                                     region_anno_dict,
-                                                    ans_vocab, vocab, image_dir,
-                                                    mean_image, 1+i*batch_size,
-                                                    batch_size, 25, 25, 3)
-
-
+                                                    ans_vocab, vocab,
+                                                    image_regions_dir, mean_image,
+                                                    1+i*batch_size, batch_size,
+                                                    75, 75, 3)
+
             feed_dict_train = {
                 image_regions : train_region_images,
                 questions: train_questions,
@@ -152,38 +180,61 @@ def train(train_params):
                 y: train_ans_labels,
                 region_score: train_region_score,
             }
-
-            tf.shape(y_pred)
-
-            q_feat = tf.get_collection('q_feat', scope='ans/q_embed')
-            _,current_train_batch_acc,q_feat_eval = \
-                sess.run([train_step, accuracy, q_feat[0]],
+            start1 = time.time()
+            _, current_train_batch_acc, y_avg_eval, y_pred_eval, logits_eval = \
+                sess.run([train_step, accuracy, y_avg, y_pred, logits],
                          feed_dict=feed_dict_train)
-
-#            print(q_feat_eval)
-            # print(q_feat_eval.shape)
-#            print(i)
-#            print(train_questions)
-#            print(train_ans_labels)
-#            print(train_region_score)
-
+            end1 = time.time()
+            # print('Training Pass: ' + str(end1-start1))
             train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] + \
                                            current_train_batch_acc
+            try:
+                assert (not np.any(np.isnan(y_avg_eval)))
+            except AssertionError:
+                print('Run NaNs coming')
+                print(1+i*batch_size)
+                print(y_avg_eval)
+                exit(1)
+
+            if (i+1)%500==0:
+                print(logits_eval[0:22,:])
+                print(train_region_score[0,0:22])
+#                print(train_ans_labels[0,:])
+                print(y_avg_eval[0,:])
+#                print(y_pred_eval)
+                val_accuracy = evaluate(accuracy, qa_anno_dict,
+                                        region_anno_dict, ans_vocab, vocab,
+                                        image_regions_dir, mean_image,
+                                        val_start_id, val_batch_size_small,
+                                        placeholders, 75, 75)
+
+                print('Iter: ' + str(i+1) + ' Val Sm Acc: ' + str(val_accuracy))
+
+        end = time.time()
+        print('Per Iter Time: ' + str(end-start))
 
         train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] / max_iter
-        # val_accuracy = evaluate(accuracy, qa_anno_dict,
-        #                         region_anno_dict, ans_vocab, vocab,
-        #                         image_dir, mean_image, 9501, 499,
-        #                         placeholders, 25, 25)
-        # val_acc_array_epoch[epoch] = val_accuracy
-        # print(val_accuracy)
-        # plotter.plot_accuracies(xdata=np.arange(0, epoch + 1) + 1,
-        #                         ydata_train=train_acc_array_epoch[0:epoch + 1],
-        #                         ydata_val=val_acc_array_epoch[0:epoch + 1],
-        #                         xlim=[1, max_epoch], ylim=[0, 1.0],
-        #                         savePath=os.path.join(outdir,
-        #                                               'acc_vs_epoch.pdf'))
-        save_path = saver.save(sess, os.path.join(outdir, 'ans_classifier'),
+
+        start = time.time()
+        val_acc_array_epoch[epoch] = evaluate(accuracy, qa_anno_dict,
+                                              region_anno_dict, ans_vocab, vocab,
+                                              image_regions_dir, mean_image,
+                                              val_start_id, val_batch_size,
+                                              placeholders, 75, 75)
+        end = time.time()
+        print('Per Validation Time: ' + str(end-start))
+        print('Val Acc: ' + str(val_acc_array_epoch[epoch]) +
+              ' Train Acc: ' + str(train_acc_array_epoch[epoch]))
+
+        plotter.plot_accuracies(xdata=np.arange(0, epoch + 1) + 1,
+                                ydata_train=train_acc_array_epoch[0:epoch + 1],
+                                ydata_val=val_acc_array_epoch[0:epoch + 1],
+                                xlim=[1, max_epoch], ylim=[0, 1.0],
+                                savePath=os.path.join(outdir,
+                                                      'acc_vs_epoch_q_o_atr.pdf'))
+
+        save_path = saver.save(sess,
+                               os.path.join(outdir, 'ans_classifier_question_obj_atr_only'),
                                global_step=epoch)
 
     sess.close()
@@ -191,6 +242,6 @@ def train(train_params):
 
 if __name__=='__main__':
     train_params = {
-        'adam_lr' : 0.001,
+        'adam_lr' : 0.00001,
     }
     train(train_params)
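
In the training-script hunks above, `evaluate` now sweeps the validation range in fixed-size minibatches and averages the per-batch accuracies instead of evaluating one example at a time. Two consequences are worth noting: any tail of `val_batch_size` not divisible by `batch_size` is silently dropped, and the result is a mean of batch means (which equals overall accuracy here only because every batch has the same size). A self-contained sketch of that arithmetic with a stubbed per-batch scorer:

```python
import math

def batched_accuracy(eval_batch_fn, start_index, val_batch_size, batch_size):
    # Mirrors evaluate(): fixed-size minibatches, mean of per-batch accuracies.
    max_iter = int(math.floor(val_batch_size / batch_size))
    correct = 0.0
    for i in range(max_iter):
        correct += eval_batch_fn(start_index + i * batch_size, batch_size)
    return correct / max_iter

# Stub standing in for accuracy.eval(feed_dict) on one minibatch.
fake_batch_acc = lambda start, size: 0.75
print(batched_accuracy(fake_batch_acc, 106115, 5000, 20))  # 0.75
```
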
diff --git a/classifiers/region_ranker/perfect_ranker.py b/classifiers/region_ranker/perfect_ranker.py
index a26d8e159d3a3c0611a5d32937e39cc39a943acf..9d44e14f6c3f988f2e4504ba9b0569554071339f 100644
--- a/classifiers/region_ranker/perfect_ranker.py
+++ b/classifiers/region_ranker/perfect_ranker.py
@@ -1,5 +1,6 @@
 import json
 import os
+import math
 import numpy as np
 from collections import namedtuple
 import matplotlib.pyplot as plt
@@ -19,8 +20,8 @@ def parse_region_anno(json_filename):
     return region_anno_dict
 
 
-def get_region_coords():
-    region_coords = np.array([[ 1, 1, 100, 100],
+def get_region_coords(img_height, img_width):
+    region_coords_ = np.array([[ 1, 1, 100, 100],
                               [ 101, 1, 200, 100],
                               [ 201, 1, 300, 100],
                               [ 1, 101, 100, 200],
@@ -42,9 +43,16 @@ def get_region_coords(img_height, img_width):
                               [ 1, 201, 200, 300],
                               [ 101, 201, 300, 300],
                               [ 1, 1, 300, 300]])
-    return region_coords
 
-def rank_regions(image, question, region_coords, gt_regions_for_image):
+    region_coords = np.copy(region_coords_)
+    region_coords[:,0] = np.ceil(region_coords[:,0]*img_width/300.0)
+    region_coords[:,1] = np.ceil(region_coords[:,1]*img_height/300.0)
+    region_coords[:,2] = np.round(region_coords[:,2]*img_width/300.0)
+    region_coords[:,3] = np.round(region_coords[:,3]*img_height/300.0)
+    return region_coords, region_coords_
+
+
+def rank_regions(image, question, region_coords, region_coords_,
+                 gt_regions_for_image, crop=True):
 
     num_regions, _ = region_coords.shape
     regions = dict()
@@ -52,17 +60,27 @@ def rank_regions(image, question, region_coords, region_coords_,
     count = 0;
     no_region_flag = True
     for i in xrange(num_regions):
+        x1_ = region_coords_[i,0]
+        y1_ = region_coords_[i,1]
+        x2_ = region_coords_[i,2]
+        y2_ = region_coords_[i,3]
+
         x1 = region_coords[i,0]
        y1 = region_coords[i,1]
         x2 = region_coords[i,2]
         y2 = region_coords[i,3]
-
-        cropped_image = image[y1-1:y2, x1-1:x2, :]
+
+        if crop:
+            cropped_image = image[y1-1:y2, x1-1:x2, :]
+        else:
+            cropped_image = None
+
         score = 0
         for gt_region in gt_regions_for_image:
-            x1_, y1_, x2_, y2_ = gt_regions_for_image[gt_region]
-            if x1==x1_ and x2==x2_ and y1==y1_ and y2==y2_:
+            gt_x1, gt_y1, gt_x2, gt_y2 = gt_regions_for_image[gt_region]
+            if gt_x1==x1_ and gt_x2==x2_ and gt_y1==y1_ and \
+               gt_y2==y2_ and gt_region in question:
                 score = 1
                 no_region_flag = False
                 break
 
diff --git a/classifiers/region_ranker/perfect_ranker.pyc b/classifiers/region_ranker/perfect_ranker.pyc
index ac7c5aca66999387ddddd8c1307cd88b3c37391b..6807d1410576a03a277f3b0cf87f780fe3fe750f 100644
Binary files a/classifiers/region_ranker/perfect_ranker.pyc and b/classifiers/region_ranker/perfect_ranker.pyc differ
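
`get_region_coords` now returns two coordinate sets: the canonical boxes on the 300x300 grid (`region_coords_`, still used to match the ground-truth region annotations and the question text in `rank_regions`) and a copy rescaled to the working resolution for cropping. Ceiling on the top-left corner and rounding on the bottom-right keeps the 1-based inclusive boxes non-degenerate and inside the image. A standalone check of the scaling rule on the first two grid cells and the whole-image box at 75x75:

```python
import numpy as np

def scale_boxes(boxes_300, img_height, img_width):
    # Boxes are 1-based inclusive [x1, y1, x2, y2] on a 300x300 grid.
    scaled = boxes_300.astype(np.float64)
    scaled[:, 0] = np.ceil(scaled[:, 0] * img_width / 300.0)
    scaled[:, 1] = np.ceil(scaled[:, 1] * img_height / 300.0)
    scaled[:, 2] = np.round(scaled[:, 2] * img_width / 300.0)
    scaled[:, 3] = np.round(scaled[:, 3] * img_height / 300.0)
    return scaled.astype(np.int64)

boxes = np.array([[1, 1, 100, 100], [101, 1, 200, 100], [1, 1, 300, 300]])
print(scale_boxes(boxes, 75, 75))
# [[ 1  1 25 25]
#  [26  1 50 25]
#  [ 1  1 75 75]]
```
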
diff --git a/classifiers/tf_graph_creation_helper.py b/classifiers/tf_graph_creation_helper.py
index 1b18cc04f8761e8f7e20b0ffed0c044627280c98..be5d85448d0d3e55c750d9e8d37fed703982919a 100644
--- a/classifiers/tf_graph_creation_helper.py
+++ b/classifiers/tf_graph_creation_helper.py
@@ -1,14 +1,15 @@
 import numpy as np
+import math
 import tensorflow as tf
 import answer_classifier.ans_data_io_helper as ans_io_helper
 
 
-def weight_variable(shape, var_name = 'W'):
-    initial = tf.truncated_normal(shape, stddev=0.1)
+def weight_variable(shape, var_name = 'W', std=0.1):
+    initial = tf.truncated_normal(shape, stddev=std)
     return tf.Variable(initial, name=var_name)
 
 
 def bias_variable(shape, var_name = 'b'):
-    initial = tf.constant(0.1, shape=shape)
+    initial = tf.constant(0.001, shape=shape)
     return tf.Variable(initial, name=var_name)
 
@@ -96,13 +97,14 @@ def ans_comp_graph(image_regions, questions, keep_prob, \
                    obj_feat, atr_feat, vocab, inv_vocab, ans_vocab_size):
     with tf.name_scope('ans') as ans_graph:
         with tf.name_scope('word_embed') as word_embed:
-            initial = tf.truncated_normal(shape=[len(vocab),100], stddev=0.1)
-#            initial = tf.random_uniform(shape=[len(vocab),100], minval=0, maxval=1)
+            initial = tf.truncated_normal(shape=[len(vocab),50],
+                                          stddev=math.sqrt(3.0/(31.0+300.0)))
             word_vecs = tf.Variable(initial, name='word_vecs')
 
         with tf.name_scope('q_embed') as q_embed:
             q_feat = tf.matmul(questions, word_vecs)
-#            q_feat = tf.truediv(q_feat, tf.cast(len(vocab),tf.float32))
+            # q_feat = tf.truediv(q_feat, tf.cast(len(vocab),tf.float32))
+            q_feat = tf.truediv(q_feat, tf.reduce_sum(questions,1,keep_dims=True))
 
         with tf.name_scope('conv1') as conv1:
             W_conv1 = weight_variable([5,5,3,4])
@@ -120,21 +122,38 @@ def ans_comp_graph(image_regions, questions, keep_prob, \
             h_pool2_drop_flat = tf.reshape(h_pool2_drop, [-1, 392],
                                            name='h_pool_drop_flat')
 
         with tf.name_scope('fc1') as fc1:
-            W_region_fc1 = weight_variable([392, ans_vocab_size], var_name='W_region')
-            W_obj_fc1 = weight_variable([392, ans_vocab_size], var_name='W_obj')
-            W_atr_fc1 = weight_variable([392, ans_vocab_size], var_name='W_atr')
-            W_q_fc1 = weight_variable([100, ans_vocab_size], var_name='W_q')
-            b_fc1 = bias_variable([ans_vocab_size])
-
-            y_pred = tf.nn.softmax(tf.matmul(h_pool2_drop_flat, W_region_fc1) + \
-                                   tf.matmul(obj_feat, W_obj_fc1) + \
-                                   tf.matmul(atr_feat, W_atr_fc1) + \
-                                   tf.matmul(q_feat, W_q_fc1) + b_fc1)
+            fc1_dim = 300
+            W_region_fc1 = weight_variable([392, fc1_dim], var_name='W_region')
+            W_obj_fc1 = weight_variable([392, fc1_dim], var_name='W_obj',
+                                        std=math.sqrt(3.0/(392.0+ans_vocab_size)))
+            W_atr_fc1 = weight_variable([392, fc1_dim], var_name='W_atr',
+                                        std=math.sqrt(3.0/(392.0+ans_vocab_size)))
+            W_q_fc1 = weight_variable([50, fc1_dim], var_name='W_q',
+                                      std=math.sqrt(3.0/(50.0+ans_vocab_size)))
+            b_fc1 = bias_variable([fc1_dim])
+
+            h_tmp = tf.matmul(q_feat, W_q_fc1) + b_fc1 + \
+                    tf.matmul(obj_feat, W_obj_fc1) + \
+                    tf.matmul(atr_feat, W_atr_fc1)
+
+            #tf.matmul(h_pool2_drop_flat, W_region_fc1) + \
+
+            h_fc1 = tf.nn.relu(h_tmp, name='h')
+            h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob, name='h_drop')
+
+        with tf.name_scope('fc2') as fc2:
+            W_fc2 = weight_variable([fc1_dim, ans_vocab_size],
+                                    std=math.sqrt(3.0/(fc1_dim)))
+            b_fc2 = bias_variable([ans_vocab_size])
+
+            logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
+
+            y_pred = tf.nn.softmax(logits)
 
     tf.add_to_collection('region_feat', h_pool2_drop_flat)
     tf.add_to_collection('q_feat', q_feat)
 
-    return y_pred
+    return y_pred, logits
 
 def aggregate_y_pred(y_pred, region_score, batch_size, num_proposals, ans_vocab_size):
     y_pred_list = tf.split(0, batch_size, y_pred)
@@ -153,7 +172,8 @@ def evaluation(y, y_pred):
 
 def loss(y, y_pred):
     cross_entropy = -tf.reduce_sum(y * tf.log(y_pred), name='cross_entropy')
-    return cross_entropy
+    batch_size = tf.shape(y)
+    return tf.truediv(cross_entropy, tf.cast(batch_size[0],tf.float32))
 
 
 if __name__ == '__main__':
diff --git a/classifiers/tf_graph_creation_helper.pyc b/classifiers/tf_graph_creation_helper.pyc
index cb966f7932fe3d93f724fcdb8c59eb25570ee614..4defda24febbdc5564e8f0d5b4820a16d9ee72e0 100644
Binary files a/classifiers/tf_graph_creation_helper.pyc and b/classifiers/tf_graph_creation_helper.pyc differ
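
The initializer changes in `tf_graph_creation_helper.py` follow a Glorot/Xavier-style rule, std = sqrt(3 / (fan_in + fan_out)), although the hunk mixes `ans_vocab_size` and `fc1_dim` in the denominators. One pitfall restated from the fix applied above: this code runs under Python 2 (it uses `xrange`), where an all-integer expression like `3/(392+ans_vocab_size)` evaluates to 0 and silently zero-initializes the layer, hence the float literals. A quick look at the scales involved:

```python
import math

def glorot_style_std(fan_in, fan_out):
    # std = sqrt(3 / (fan_in + fan_out)); floats guard against
    # Python 2 integer division returning 0.
    return math.sqrt(3.0 / (float(fan_in) + float(fan_out)))

print(glorot_style_std(392, 300))  # ~0.0658 (a 392 -> 300 fc layer)
print(glorot_style_std(50, 300))   # ~0.0926 (the 50-d question embedding side)
print(3 // (392 + 300))            # 0: what Python 2's "/" would give here
```
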
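The reworked `loss` divides the summed cross-entropy by the leading dimension of `y`, turning it into a per-example mean so the much smaller learning rate (0.00001) means the same thing at any batch size. Because `y_pred` comes from a bare softmax, `tf.log(y_pred)` can still hit log(0) = -inf, which is presumably what the NaN guard added to the training loop is watching for. A numpy rendering of the same mean cross-entropy, with the common clipping remedy added (the clipping is a suggestion, not part of this patch):

```python
import numpy as np

def mean_cross_entropy(y, y_pred, eps=1e-10):
    y_pred = np.clip(y_pred, eps, 1.0)   # guards against log(0) -> -inf/NaN
    ce = -np.sum(y * np.log(y_pred))     # summed cross-entropy, as in loss()
    return ce / y.shape[0]               # divided by batch size, as in loss()

y = np.array([[0., 1.], [1., 0.]])           # one-hot answers
y_pred = np.array([[0.2, 0.8], [0.9, 0.1]])  # softmax outputs
print(mean_cross_entropy(y, y_pred))         # ~0.1643
```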