From f49cbdacceb00b18191770f8bb6243092e12a612 Mon Sep 17 00:00:00 2001 From: tgupta6 <tgupta6@illinois.edu> Date: Mon, 28 Mar 2016 15:35:57 -0500 Subject: [PATCH] Need to debug --- .../answer_classifier/eval_ans_classifier.py | 172 +++++++++ .../answer_classifier/train_ans_classifier.py | 336 ++++++++++-------- .../train_ans_classifier.pyc | Bin 0 -> 8282 bytes classifiers/tf_graph_creation_helper.py | 141 +++++--- classifiers/tf_graph_creation_helper.pyc | Bin 10342 -> 11829 bytes classifiers/train_classifiers.py | 27 +- shapes_dataset/evaluate_shapes_test.py | 2 +- 7 files changed, 481 insertions(+), 197 deletions(-) create mode 100644 classifiers/answer_classifier/eval_ans_classifier.py create mode 100644 classifiers/answer_classifier/train_ans_classifier.pyc diff --git a/classifiers/answer_classifier/eval_ans_classifier.py b/classifiers/answer_classifier/eval_ans_classifier.py new file mode 100644 index 0000000..35eb657 --- /dev/null +++ b/classifiers/answer_classifier/eval_ans_classifier.py @@ -0,0 +1,172 @@ +import sys +import os +import json +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +import numpy as np +import math +import pdb +import tensorflow as tf +import tf_graph_creation_helper as graph_creator +import plot_helper as plotter +import ans_data_io_helper as ans_io_helper +import region_ranker.perfect_ranker as region_proposer +import train_ans_classifier as ans_trainer + +def get_pred(y, qa_anno_dict, region_anno_dict, ans_vocab, vocab, + image_dir, mean_image, start_index, val_set_size, batch_size, + placeholders, img_height=100, img_width=100): + + inv_ans_vocab = {v: k for k, v in ans_vocab.items()} + pred_list = [] + correct = 0 + max_iter = int(math.ceil(val_set_size*1.0/batch_size)) +# print ([val_set_size, batch_size]) +# print('max_iter: ' + str(max_iter)) + batch_size_tmp = batch_size + for i in xrange(max_iter): + if i==(max_iter-1): + batch_size_tmp = val_set_size - i*batch_size + print('Iter: ' + str(i+1) + '/' + str(max_iter)) +# print batch_size_tmp + region_images, ans_labels, questions, \ + region_score, partition= \ + ans_io_helper.ans_mini_batch_loader(qa_anno_dict, + region_anno_dict, + ans_vocab, vocab, + image_dir, mean_image, + start_index+i*batch_size, + batch_size_tmp, + img_height, img_width, 3) + + # print [start_index+i*batch_size, + # start_index+i*batch_size + batch_size_tmp -1] + if i==max_iter-1: + + residual_batch_size = batch_size - batch_size_tmp + residual_regions = residual_batch_size*ans_io_helper.num_proposals + + residual_region_images = np.zeros(shape=[residual_regions, + img_height/3, img_width/3, + 3]) + residual_questions = np.zeros(shape=[residual_regions, + len(vocab)]) + residual_ans_labels = np.zeros(shape=[residual_batch_size, + len(ans_vocab)]) + residual_region_score = np.zeros(shape=[1, residual_regions]) + + region_images = np.concatenate((region_images, + residual_region_images), + axis=0) + questions = np.concatenate((questions, residual_questions), axis=0) + ans_labels = np.concatenate((ans_labels, residual_ans_labels), + axis=0) + region_score = np.concatenate((region_score, residual_region_score), + axis=1) + # print region_images.shape + # print questions.shape + # print ans_labels.shape + # print region_score.shape + + feed_dict = { + placeholders[0] : region_images, + placeholders[1] : questions, + placeholders[2] : 1.0, + placeholders[3] : ans_labels, + placeholders[4] : region_score, + } + + ans_ids = np.argmax(y.eval(feed_dict), 1) + for j in xrange(batch_size_tmp): + pred_list = pred_list + 
[{ + 'question_id' : start_index+i*batch_size+j, + 'answer' : inv_ans_vocab[ans_ids[j]] + }] + # print qa_anno_dict[start_index+i*batch_size+j].question + # print inv_ans_vocab[ans_ids[j]] + + return pred_list + +def eval(eval_params): + sess = tf.InteractiveSession() + + train_anno_filename = eval_params['train_json'] + test_anno_filename = eval_params['test_json'] + regions_anno_filename = eval_params['regions_json'] + image_regions_dir = eval_params['image_regions_dir'] + outdir = eval_params['outdir'] + model = eval_params['model'] + batch_size = eval_params['batch_size'] + test_start_id = eval_params['test_start_id'] + test_set_size = eval_params['test_set_size'] + if not os.path.exists(outdir): + os.mkdir(outdir) + + qa_anno_dict_train = ans_io_helper.parse_qa_anno(train_anno_filename) + qa_anno_dict = ans_io_helper.parse_qa_anno(test_anno_filename) + region_anno_dict = region_proposer.parse_region_anno(regions_anno_filename) + ans_vocab, inv_ans_vocab = ans_io_helper.create_ans_dict() + vocab, inv_vocab = ans_io_helper.get_vocab(qa_anno_dict_train) + + # Create graph + g = tf.get_default_graph() + image_regions, questions, keep_prob, y, region_score= \ + graph_creator.placeholder_inputs_ans(len(vocab), len(ans_vocab), + mode='gt') + y_pred_obj = graph_creator.obj_comp_graph(image_regions, 1.0) + obj_feat_op = g.get_operation_by_name('obj/conv2/obj_feat') + obj_feat = obj_feat_op.outputs[0] + y_pred_atr = graph_creator.atr_comp_graph(image_regions, 1.0, obj_feat) + atr_feat_op = g.get_operation_by_name('atr/conv2/atr_feat') + atr_feat = atr_feat_op.outputs[0] + + y_pred = graph_creator.ans_comp_graph(image_regions, questions, keep_prob, + obj_feat, atr_feat, vocab, + inv_vocab, len(ans_vocab), + eval_params['mode']) + y_avg = graph_creator.aggregate_y_pred(y_pred, region_score, batch_size, + ans_io_helper.num_proposals, + len(ans_vocab)) + + cross_entropy = graph_creator.loss(y, y_avg) + accuracy = graph_creator.evaluation(y, y_avg) + + # Restore model + saver = tf.train.Saver() + if os.path.exists(model): + saver.restore(sess, model) + else: + print 'Failed to read model from file ' + model + + mean_image = np.load('/home/tanmay/Code/GenVQA/Exp_Results/' + \ + 'Obj_Classifier/mean_image.npy') + + placeholders = [image_regions, questions, keep_prob, y, region_score] + + # Get predictions + pred_dict =get_pred(y_avg, qa_anno_dict, region_anno_dict, ans_vocab, vocab, + image_regions_dir, mean_image, test_start_id, + test_set_size, batch_size, placeholders, 75, 75) + + json_filename = os.path.join(outdir, 'predicted_ans_' + \ + eval_params['mode'] + '.json') + with open(json_filename,'w') as json_file: + json.dump(pred_dict, json_file) + + + +if __name__=='__main__': + ans_classifier_eval_params = { + 'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json', + 'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json', + 'regions_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/regions_anno.json', + 'image_regions_dir': '/mnt/ramdisk/image_regions', + 'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier', + 'model': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier/ans_classifier_q_obj_atr-9', + 'mode' : 'q_obj_atr', + 'batch_size': 20, + 'test_start_id': 111352, #+48600, + 'test_set_size': 160725-111352+1, + } + + eval(ans_classifier_eval_params) diff --git a/classifiers/answer_classifier/train_ans_classifier.py b/classifiers/answer_classifier/train_ans_classifier.py index a4dfdde..524d294 100644 --- 
a/classifiers/answer_classifier/train_ans_classifier.py +++ b/classifiers/answer_classifier/train_ans_classifier.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt import matplotlib.image as mpimg import numpy as np import math +import random import pdb import tensorflow as tf import object_classifiers.obj_data_io_helper as obj_data_loader @@ -14,18 +15,17 @@ import plot_helper as plotter import ans_data_io_helper as ans_io_helper import region_ranker.perfect_ranker as region_proposer import time + val_start_id = 106115 -val_batch_size = 5000 -val_batch_size_small = 100 -batch_size = 20 -crop_n_save_regions = False -restore_intermediate_model = True +val_set_size = 5000 +val_set_size_small = 100 + def evaluate(accuracy, qa_anno_dict, region_anno_dict, ans_vocab, vocab, - image_dir, mean_image, start_index, val_batch_size, + image_dir, mean_image, start_index, val_set_size, batch_size, placeholders, img_height=100, img_width=100): correct = 0 - max_iter = int(math.floor(val_batch_size/batch_size)) + max_iter = int(math.floor(val_set_size/batch_size)) for i in xrange(max_iter): region_images, ans_labels, questions, \ region_score, partition= \ @@ -53,24 +53,15 @@ def evaluate(accuracy, qa_anno_dict, region_anno_dict, ans_vocab, vocab, def train(train_params): sess = tf.InteractiveSession() - train_anno_filename = '/home/tanmay/Code/GenVQA/GenVQA/' + \ - 'shapes_dataset/train_anno.json' - - test_anno_filename = '/home/tanmay/Code/GenVQA/GenVQA/' + \ - 'shapes_dataset/test_anno.json' - - regions_anno_filename = '/home/tanmay/Code/GenVQA/GenVQA/' + \ - 'shapes_dataset/regions_anno.json' - - image_dir = '/home/tanmay/Code/GenVQA/GenVQA/' + \ - 'shapes_dataset/images' + train_anno_filename = train_params['train_json'] + test_anno_filename = train_params['test_json'] + regions_anno_filename = train_params['regions_json'] + image_dir = train_params['image_dir'] + image_regions_dir = train_params['image_regions_dir'] + outdir = train_params['outdir'] + obj_atr_model = train_params['obj_atr_model'] + batch_size = train_params['batch_size'] - # image_regions_dir = '/home/tanmay/Code/GenVQA/Exp_Results/' + \ - # 'image_regions' - - image_regions_dir = '/mnt/ramdisk/image_regions' - - outdir = '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier' if not os.path.exists(outdir): os.mkdir(outdir) @@ -80,7 +71,7 @@ def train(train_params): vocab, inv_vocab = ans_io_helper.get_vocab(qa_anno_dict) # Save region crops - if crop_n_save_regions == True: + if train_params['crop_n_save_regions'] == True: qa_anno_dict_test = ans_io_helper.parse_qa_anno(test_anno_filename) ans_io_helper.save_regions(image_dir, image_regions_dir, qa_anno_dict, region_anno_dict, @@ -91,24 +82,27 @@ def train(train_params): # Create graph + g = tf.get_default_graph() image_regions, questions, keep_prob, y, region_score= \ graph_creator.placeholder_inputs_ans(len(vocab), len(ans_vocab), mode='gt') y_pred_obj = graph_creator.obj_comp_graph(image_regions, 1.0) - obj_feat = tf.get_collection('obj_feat', scope='obj/conv2') - y_pred_atr = graph_creator.atr_comp_graph(image_regions, 1.0, obj_feat[0]) - atr_feat = tf.get_collection('atr_feat', scope='atr/conv2') - - # model restoration -# obj_atr_saver = tf.train.Saver() - model_to_restore = '/home/tanmay/Code/GenVQA/GenVQA/classifiers/' + \ - 'saved_models/obj_atr_classifier-1' -# obj_atr_saver.restore(sess, model_to_restore) - - y_pred, logits = graph_creator.ans_comp_graph(image_regions, - questions, keep_prob, \ - obj_feat[0], atr_feat[0], - vocab, inv_vocab, len(ans_vocab)) + obj_feat_op 
= g.get_operation_by_name('obj/conv2/obj_feat') + obj_feat = obj_feat_op.outputs[0] + y_pred_atr = graph_creator.atr_comp_graph(image_regions, 1.0, obj_feat) + atr_feat_op = g.get_operation_by_name('atr/conv2/atr_feat') + atr_feat = atr_feat_op.outputs[0] + + # Restore obj and attribute classifier parameters + obj_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='obj') + atr_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr') + obj_atr_saver = tf.train.Saver(obj_vars+atr_vars) + obj_atr_saver.restore(sess, obj_atr_model) + + y_pred = graph_creator.ans_comp_graph(image_regions, questions, keep_prob, + obj_feat, atr_feat, vocab, + inv_vocab, len(ans_vocab), + train_params['mode']) y_avg = graph_creator.aggregate_y_pred(y_pred, region_score, batch_size, ans_io_helper.num_proposals, len(ans_vocab)) @@ -117,66 +111,137 @@ def train(train_params): accuracy = graph_creator.evaluation(y, y_avg) # Collect variables - vars_to_opt = tf.get_collection(tf.GraphKeys.VARIABLES, scope='ans') - - train_step = tf.train.AdamOptimizer(train_params['adam_lr']) \ - .minimize(cross_entropy, var_list=vars_to_opt) - - word_embed = tf.get_collection(tf.GraphKeys.VARIABLES, scope='ans/word_embed') - vars_to_restore = \ - tf.get_collection(tf.GraphKeys.VARIABLES,scope='obj') + \ - tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr') + \ - [word_embed[0]] + ans_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='ans') + list_of_vars = [ + 'ans/word_embed/word_vecs', + 'ans/fc1/W_region', + 'ans/fc1/W_obj', + 'ans/fc1/W_atr', + 'ans/fc1/W_q', + 'ans/fc1/b', + 'ans/fc2/W', + 'ans/fc2/b' + ] + vars_dict = graph_creator.get_list_of_variables(list_of_vars) - all_vars = tf.get_collection(tf.GraphKeys.VARIABLES) + if train_params['mode']=='q': + pretrained_vars_high_lr = [] + pretrained_vars_low_lr = [] + partial_model = '' - vars_to_init = [var for var in all_vars if var not in vars_to_restore] - vars_to_save = tf.get_collection(tf.GraphKeys.VARIABLES,scope='obj') + \ - tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr') + \ - tf.get_collection(tf.GraphKeys.VARIABLES, scope='ans') - - print('vars_to_save: ') - print([var.name for var in vars_to_save]) - # Session saver - - saver = tf.train.Saver(vars_to_restore) - saver2 = tf.train.Saver(vars_to_save) - if restore_intermediate_model==True: - intermediate_model = '/home/tanmay/Code/GenVQA/Exp_Results/' + \ - 'Ans_Classifier/ans_classifier_question_only-9' - print('vars_to_restore: ') - print([var.name for var in vars_to_restore]) - print('vars_to_init: ') - print([var.name for var in vars_to_init]) - saver.restore(sess, intermediate_model) -# print('Initializing variables') - sess.run(tf.initialize_variables(vars_to_init)) - start_epoch = 0 + elif train_params['mode']=='q_obj_atr' or \ + train_params['mode']=='q_reg': + + pretrained_vars_low_lr = [ + vars_dict['ans/word_embed/word_vecs'], + ] + pretrained_vars_high_lr = [ + vars_dict['ans/fc1/W_q'], + vars_dict['ans/fc1/b'], + vars_dict['ans/fc2/W'], + vars_dict['ans/fc2/b'] + ] + partial_model = os.path.join(outdir, 'ans_classifier_q-' + \ + str(train_params['start_model'])) + + elif train_params['mode']=='q_obj_atr_reg': + pretrained_vars_low_lr = [ + vars_dict['ans/word_embed/word_vecs'], + vars_dict['ans/fc1/W_q'], + vars_dict['ans/fc1/W_obj'], + vars_dict['ans/fc1/W_atr'], + vars_dict['ans/fc1/b'], + ] + pretrained_vars_high_lr = [ + vars_dict['ans/fc2/W'], + vars_dict['ans/fc2/b'] + ] + partial_model = os.path.join(outdir, 'ans_classifier_q_obj_atr-' + \ + str(train_params['start_model'])) + + 
# Fine tune begining with a previous model + if train_params['fine_tune']==True: + partial_model = os.path.join(outdir, 'ans_classifier_' + \ + train_params['mode'] + '-' + \ + str(train_params['start_model'])) + start_epoch = train_params['start_model']+1 else: - # Initializing all variables except those restored - print('Initializing variables') - sess.run(tf.initialize_variables(vars_to_init)) start_epoch = 0 + # Restore partial model + vars_to_save = obj_vars + atr_vars + ans_vars + partial_saver = tf.train.Saver(vars_to_save) + if os.path.exists(partial_model): + partial_saver.restore(sess, partial_model) + + # Variables to train from scratch + all_pretrained_vars = pretrained_vars_low_lr + pretrained_vars_high_lr + vars_to_train_from_scratch = \ + [var for var in ans_vars if var not in all_pretrained_vars] + + # Attach optimization ops + train_step_high_lr = tf.train.AdamOptimizer(train_params['adam_high_lr']) \ + .minimize(cross_entropy, + var_list=vars_to_train_from_scratch + + pretrained_vars_high_lr) + print('Parameters trained with high lr (' + + str(train_params['adam_high_lr']) + '): ') + print([var.name for var in vars_to_train_from_scratch + + pretrained_vars_high_lr]) + + if pretrained_vars_low_lr: + train_step_low_lr = tf.train \ + .AdamOptimizer(train_params['adam_low_lr']) \ + .minimize(cross_entropy, + var_list=pretrained_vars_low_lr) + print('Parameters trained with low lr(' + + str(train_params['adam_low_lr']) + '): ') + print([var.name for var in pretrained_vars_low_lr]) + + all_vars = tf.get_collection(tf.GraphKeys.VARIABLES) + + if train_params['fine_tune']==False: + vars_to_init = [var for var in all_vars if var not in + obj_vars + atr_vars + all_pretrained_vars] + else: + vars_to_init = [var for var in all_vars if var not in vars_to_save] + + # Initialize vars_to_init + sess.run(tf.initialize_variables(vars_to_init)) + + print('All pretrained variables: ') + print([var.name for var in all_pretrained_vars]) + print('Variables to train from scratch: ') + print([var.name for var in vars_to_train_from_scratch]) + print('Variables to initialize randomly: ') + print([var.name for var in vars_to_init]) + print('Variables to save: ') + print([var.name for var in vars_to_save]) + # Load mean image mean_image = np.load('/home/tanmay/Code/GenVQA/Exp_Results/' + \ 'Obj_Classifier/mean_image.npy') placeholders = [image_regions, questions, keep_prob, y, region_score] - - # Variables to observe - W_fc2 = tf.get_collection(tf.GraphKeys.VARIABLES, scope='ans/fc2/W') - q_feat = tf.get_collection('q_feat', scope='ans/q_embed') + + if train_params['fine_tune']==True: + restored_accuracy = evaluate(accuracy, qa_anno_dict, + region_anno_dict, ans_vocab, + vocab, image_regions_dir, + mean_image, val_start_id, + val_set_size, batch_size, + placeholders, 75, 75) + print('Accuracy of restored model: ' + str(restored_accuracy)) # Start Training -# batch_size = 1 - max_epoch = 10 + max_epoch = train_params['max_epoch'] max_iter = 5000 val_acc_array_epoch = np.zeros([max_epoch]) train_acc_array_epoch = np.zeros([max_epoch]) for epoch in range(start_epoch, max_epoch): - start = time.time() - for i in range(max_iter): + iter_ids = range(max_iter) + random.shuffle(iter_ids) + for i in iter_ids: #range(max_iter): train_region_images, train_ans_labels, train_questions, \ train_region_score, train_partition= \ @@ -194,100 +259,63 @@ def train(train_params): region_score: train_region_score, } + if pretrained_vars_low_lr: + _, _, current_train_batch_acc, y_pred_eval, loss_eval = \ + 
sess.run([train_step_low_lr, train_step_high_lr, + accuracy, y_pred, cross_entropy], + feed_dict=feed_dict_train) + else: + _, current_train_batch_acc, y_pred_eval, loss_eval = \ + sess.run([train_step_high_lr, accuracy, + y_pred, cross_entropy], + feed_dict=feed_dict_train) + + assert (not np.any(np.isnan(y_pred_eval))), 'NaN predicted' - try: - assert (not np.any(np.isnan(q_feat[0].eval(feed_dict_train)))) - except AssertionError: - print('NaN in q_feat') - print(1+i*batch_size) - print(train_questions) - print(logits.eval(feed_dict_train)) - print(cross_entropy.eval(feed_dict_train)) - exit(1) - - start1 = time.time() - _, current_train_batch_acc, y_avg_eval, y_pred_eval, logits_eval, W_fc2_eval = \ - sess.run([train_step, accuracy, y_avg, y_pred, logits, W_fc2[0]], - feed_dict=feed_dict_train) - end1 = time.time() - # print('Training Pass: ' + str(end1-start1)) train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] + \ current_train_batch_acc - # pdb.set_trace() - - - try: - assert (not np.any(np.isnan(W_fc2_eval))) - except AssertionError: - print('NaN in W_fc2') - print(1+i*batch_size) - print(W_fc2_eval) - exit(1) - try: - assert (not np.any(np.isnan(logits_eval))) - except AssertionError: - print('NaN in logits') - print(1+i*batch_size) - print(y_avg_eval) - exit(1) - - try: - assert (not np.any(np.isnan(y_avg_eval))) - except AssertionError: - print('NaN in y_avg') - print(1+i*batch_size) - print(logits_eval) - print(y_avg_eval) - exit(1) - if (i+1)%500==0: - print(logits_eval[0:22,:]) - print(train_region_score[0,0:22]) - print(train_ans_labels[0,:]) -# print(train_ans_labels[0,:]) - print(y_avg_eval[0,:]) -# print(y_pred_eval) val_accuracy = evaluate(accuracy, qa_anno_dict, region_anno_dict, ans_vocab, vocab, image_regions_dir, mean_image, - val_start_id, val_batch_size_small, - placeholders, 75, 75) + val_start_id, val_set_size_small, + batch_size, placeholders, 75, 75) - print('Iter: ' + str(i+1) + ' Val Sm Acc: ' + str(val_accuracy)) - - end = time.time() - print('Per Iter Time: ' + str(end-start)) + print('Iter: ' + str(i+1) + ' Val Sm Acc: ' + str(val_accuracy) + + ' Loss: ' + str(loss_eval)) train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] / max_iter - start = time.time() val_acc_array_epoch[epoch] = evaluate(accuracy, qa_anno_dict, - region_anno_dict, ans_vocab, vocab, - image_regions_dir, mean_image, - val_start_id, val_batch_size, + region_anno_dict, ans_vocab, + vocab, image_regions_dir, + mean_image, val_start_id, + val_set_size, batch_size, placeholders, 75, 75) - end=time.time() - print('Per Validation Time: ' + str(end-start)) + print('Val Acc: ' + str(val_acc_array_epoch[epoch]) + ' Train Acc: ' + str(train_acc_array_epoch[epoch])) - plotter.plot_accuracies(xdata=np.arange(0, epoch + 1) + 1, + if train_params['fine_tune']==True: + plot_path = os.path.join(outdir, 'acc_vs_epoch_' \ + + train_params['mode'] + '_fine_tuned.pdf') + else: + plot_path = os.path.join(outdir, 'acc_vs_epoch_' \ + + train_params['mode'] + '.pdf') + + plotter.plot_accuracies(xdata=np.arange(0, epoch + 1) + 1, ydata_train=train_acc_array_epoch[0:epoch + 1], ydata_val=val_acc_array_epoch[0:epoch + 1], xlim=[1, max_epoch], ylim=[0, 1.0], - savePath=os.path.join(outdir, - 'acc_vs_epoch_q_o_atr.pdf')) + savePath=plot_path) - save_path = saver2.save(sess, - os.path.join(outdir,'ans_classifier_question_obj_atr'), - global_step=epoch) + save_path = partial_saver \ + .save(sess, os.path.join(outdir, 'ans_classifier_' + \ + train_params['mode']), global_step=epoch) sess.close() 
tf.reset_default_graph() if __name__=='__main__': - train_params = { - 'adam_lr' : 0.0001, - } - train(train_params) + print 'Hello' diff --git a/classifiers/answer_classifier/train_ans_classifier.pyc b/classifiers/answer_classifier/train_ans_classifier.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a476b58d08ede1221bca8d1d9e83afa1f34f095 GIT binary patch literal 8282 zcmcIpOK%+6bw1V2xA+n%vdN(ytp_#K=)sX@hMLi6Byl7SDbGY6w;WB-uu>>?Rg=}K zu5MNpDblF&BJv^zvdPwfy)?YeFUUVIU?AQl$TCX@36OYYf8V)P&8DOn$inPq^|`Ne z&OP^>@7{Z={9hAekN5uYydn8t8NUzkAD^*ATzoAFMCv`-lDbvYm!w`Q=(4m1q&`>_ zhom(u^^u}DD)q5~J|XoJ1wAhH@q(U^`b0rbN`0!Jr=@jL>N7=gR$8Z|K35dyrCurM z1*tC<^m|fYD(KVFIwSS7Me&@}&xs4=tKycWeqLN?et}PZ2j!<|x+w0D{1g+vB;K;P z!}78u(W2#!$i`-<R(ZBE3N!hC#NSi@^)D^)|9%!_m`oZIc+BIUzYp;r{}vBeGMkV@ zB(bDbl5|AUQJH&f3BK8+9Q<06vZRwNK6>$U$)+S3kk+82mZa#8wgE{|9FlZc(kYHU zRly@YAt{VllGi8XM}v}0OLkJS8OdfPJ0(6wOiK#uTe7&8&L~+1iMg_NkhLe3ERmd5 zatLHLC+UQwA}vUqqgoZ{zk5x-WA&uWJzE*Zw4VBSh)=)q;WDz#qR8*>`z8GN19)gZ zOb*2`!K6gmPO_5>f@lVvPLvFRI*6Qb+e>INCyZ^sWAAuD*Nc+5yr%7kzP;rnjU7Aa zI4)Yq$=i2=71}x(VZTOWFLD}($v6tTj_rhD$9DZjl1!r_^0xg>*jLFFq_F)?!`V`O z45cgEcD6l?jgm3+YI{y-`>aY%pcp4kl-PdgdIu^BO4zZN*s=d(Px~7^wMbREfz$AI zIsvqfwcT%TV}AbjPLdd5+x9Epg;8kIMkk6q=*$_loderXFj?BxSM%i4P;Cr1&LMEN zydYL?yL(=oV5YI^mUoZA)zjuKbobe=rKab(YN8bg5Pigd4Qh3#)Am*qCu}>1s}DP_ zx4Pkl&wsF9{4@e5j{T;O8LdM8D=)H-s#cT8@oDL}ukuFsP^ZKt+jA1H0x9t&N*x%o z7OmOxs8zOx@Hv6sv!ybA&1bS)DK$9%ukn8p{~A**tP7&4g#Z%C5F{F6iy2}C6r4Fj z$bf>gXYh4U1E3HNpy2fxyc!fdJwrHw8U}@!1T_L`RFaak6x9%JaQ^rwT7&Ugw15Ke zTc8|*6V2HeJ+Osv9MIwkZ5^X-KPpK!uCb5)2v$^DS~<xer@&EJ#8MGru2?zKvT~VL zE(kXh_hSrp7g38RLVjFckBAQF^+2ej4LJI3Uk|DZLjhV~sFRXRNehe-^a?g4zzL`h z$eotgC5tFh;0I73o7W+L4uJ4Ubdanfz>L<cxHGEef}{(CWp@^qTh!hF5AB^U%ImUZ z?-c`1>3}6kmwE$Et9B3q9V^mW5Etk^kA7Bm=3Pw{Z91z>i_!w?Gm@T_3RmHrr01js z_V9-hs%mkzr>P8~?UX>D6vas*Xm93RgDF{jJt|pMwFkqhr1QjBt_NflMtTj~oYy`O zg-M=QdtA^O4C&Qee3#}6s`({xu}!@vE;gvs;+_!~TN8Gqs%(!)wydnqOS-I<x~vWY z<=pe?AfsgfZ1z53Hd_(*f@D`D{EuFzi|iC%Dd#hCFX^D4l+();y0{G!tq2QeWRX8- z&CfaWQ$>#fa?gGL0-fE3Mi*7UWta%N0385A<DSs>Rmgu9Li*S*8NR1`!nR>Pqd=cX z@Owk#%DY91ioDW`mi2*SihfXNRB^9L<y}hCESLpCsFHmk;SfC*MD9T8B~)MS^Q)`6 z1lJ0407b5_Ci;hg^jhI%sDYOu#8S-beK5GrwE?R~FBT-bAt}Ds<>jzMzZd}hw<X#A z#Xz=NNV%jsUsKDrmL<hG<hrD*%8C*2p}02+W0+nQcU3AIn-pTYJ$f;f3u;|f{zHQs zo6siu@gPGAA!3Al_=b=b^~Rgx-Yf)|9(@8<FpD>sH^4jxl)jK<?A_|Ej&X?J3{uBi zU{Ng$VFDHQsL&w48q5vytHJCe376Gy?B6roRW;o2ihjRiKZJG*{(+fYk@O?32fmhj z3nw$myZ*My^Fte(?nin1>f7yWZ2vfK|LE=Z=WPEZZ@>C>`##!}52STf+%<7;8zlO; z4|9MD^Fr7Eq-P=n_XESpHAxX&itK=40uwf-k0rS-2|Ne;$tMMpPe=y#pXy?*ankSt z`W<}wGvMxRaqmcmW@rM(xGe~FF3V?c)!dPTN;&yZQh5Gty6>YGL*?vqVb!NP=B^|+ zrG*8(tHwR}M+?-W7eBMIFBA~53U6tQph~}>#(F9>`;+^5aVoth*?q|#NVYE7mwGC_ zM+2>;4+zX_?p=cOT6&)lx|XifU~Ab!232%s#4LGwQ&MOMYgHL1u%`=o522y%SdS#b z_pxMONwy&&qDUR(*btoWdqV(pef?=QSRGHS(N~-`zT<IcB!SKEvG8};Qtn9lIk4=p zR1A3+GbWh0_hlYX!W}&PvrBsWttH(Q_kk{{;RHbycio`ZCdF3VFHLDfOB<WJ;V{16 z6}o&XbGa`6FIP%G<5cdt4^_4n6x`iMMw*!^<pWFLVMn0*^U-6V`&b<i8&t*p%IpnK zB*YGBPETeB>Fcvm^hrLI)@R~ws?F~Zh@lT{34c_RKB2jfs?WRM6DZwZ@ijZr(J|s1 zw%~L^`mZE?g0s$$MU?)WesD~i4|*C?)KkjEp(p#E>~6xWa2hzh;NSp+`9djI<k`v( zaPEAnXLob9ZN>QkI>BMv))G(p=CmDaMU53WIpE-~mx-c{mkM5aI(rF8UU@oOE!#;V zyN!c(P~&AK!EI6+QKxH%cI@nXc0O7!S^oe06E2)Q4*UO0>R;j#Z;QAtU_2&^Yghs> z<lpb_$20hWfYnAP+`qL-QBBWD^xmdz5S+JB1-9v32sehVr`I+N(O)#1-!P6Bz!c~E z)mNR!wY~P1=NfX~Ys4`pz{Y0d=IVEa=K0kwuLRHdcn!G6yj&d(+>P@anvU7(E8SZC z?pSF{??<~>Ggyv;c5Tj+3iFz|OC4ga#PKVMy?edyKBval<>-ypaf|O5SFpcsE>V8y z*~wn$;k5ad3{~yA-o+f(Y1=!v_Sr$CccFbJvI9R(;`12ytrIzI50|~Tsy^nq)mMJ9 zQ{{ka5LH*=GJ0LPQ&pXnSI~I{R`Iz%Ga6kmcqM)heb$4Z+Ql`Hw=-~ubl=$uy!cL4 
z6?k5ht4YVGP;Ew?b~SE9ymzD51u!|*3zzc5cLLmdtGK|sopx}DZuH$_-RQxnUq=1^ zd9i$a(6wt`ycZ<#YJ3JgzmFAs*t=$~9^Jujgx$mVG<vM(_x)<8S&a~fxZ=81T}Bv* zZouoa*X=ZROm>F$%j4qT@7KsdeekLClxAU~0FVF1RWoLMKEYBzX?kh(xf4{MwW|<_ znqk&_(}`&}N@1gEVZ>K|Nb{Nslh0^0?ETp2VaN0kyIDAld!y?%V}>pXb=CvdNgQ3` zLsGV`md?hMaedZUKM4G`5{D$1ETEOYWom_Cv>kM|n3hPquK%z945ausJXY?ZMFApd zCd7p&AtvcG5`W)&=HWsQP$PVFV&&W=Gn&FXpp&Y4?Uw+p8a39*A)63;cAjfd_cuUg zcY$4<7?}%(o#~agFe{@y8xgR^(=>r5X=qDiSn_-fMf@<@^E7h$Ay7H8pX)UpxQo3V zIo%!AMk!nQBCo`}P+#f=37hU-5>uYSSm1@pB)SkO8l84G=fi{(xdZ*+_RJV<9oiuy zlj{!9CKDY9c<CQVUL4ry1Of6lMTTXA+I{UEn&iy$_1cs5FTZ*GOpR&QmJsy}YZs}6 z+>wn0W>Ah)Q(WENh8=0QL)h4Jb>87#J2$w4BoHkZz&&-l%F`{>rkEKK$L=)kUi|Bb zRtM=hMo0|Qq|>hht-jw){5GQ7z!h_FENOo=3040JGR;O2?S;uH)EuEgpA7m{*v*x5 z+!{k%*A@;9H`R=qTVr^rX&@6`Nd6`>NH;VPqgIM{_L|MWtFc4ADIIEWKMozUuIn+~ z4-<bJMeqi?yNbdRF*9e!Z6^qnSvTk;OtMipX**9O`oI**1u!!?NxRY!RZasI_H>uV zg1zaCE8LrPa=OAc{NHJtg^3YnI?S|ain8g$7wKWtQ4bWytGRF{c<8;zX%++Yh*9I_ zQ6n_hF7iPysSyuq45%9Om%5&QxUaFMNzNKELQ#d)Chy5e?a!LpR7JP1ZNs4d7*pf6 zs0?0uUYFiyHsQk>cN%p)<DlFs6_FTcJY68N6t0SLe_!s{dI-o*B1K<@nx0HJFAjw7 z#>#~)TA<6ag%w;SbI8hrgl=+ZJNsrLI9J60<S+q79o87uh81(c<VzUu3lbB#Y%$4X zW-gDmQ619T)fIDkTdZj_yR@1Rj==rW9Q1Ui#`ATP%A)~dl8#Q{*bp-(Zmn=zhQ2xw zvbU--0~bU><Ey4%VLXgmxSem|#=$f>%Z;I@HCNvB*6Lhg+M2Rd0Buq*W_;IZ@0t$e zqu73n&^fH#W;;Dsy`7PVsIr|Xat_U2sCbk|!5b}HE@Kkvzy!GMyKzku{j`7?mG38c zCR|U=y}Wat4_7$R%A>5f2AS673`}nPsH-lL;y3Gh8S;E5W+vKB=mU7!5W$IZJDZ$8 z=BUe>`##ZEgM@f&YsnZ=tzlkPFiI75bh~AVOAX9Oc2-dMODE-Pe*+nxQ0f7vW{sD| ztYe?jj5S@FvZkyVYpy(v-*F`Ar}a5IFkV`+7SJ+fO;~e$mY1w0YhFi;qR#~ChO8;@ zm?<wIso(z@ElriCHRV5w-|6y1dD!A-0@5m_>5?g-jC!qu^l|VzWt}Vyl_w!}2Cogq zO3USWa4B13r3!u*OGDNu<l>3)+0q<&TUGWd4MMlRf4p)S)Amu-f-fI1%kmh0$$MV) zpBUhLhpb6`DN*P&Rvs!*m!<MJdd-!Vq3wy%DM&WFX3A@r^AX?K^0I2bfKhX$i`Gee z7OY|DI%6%t;tD{v-R64}TQ`_ZF9<p-m$4*fS3ZnuU$VemNbj2rJG@Dw>jeH5Ua#mj z#{DWMi3#1ny!`6xXx?Kn#Ez(T7q6sLERGVuL?5j+x<d^qCZj;t-7@!IFN`}8-rT*a zJz(219>86kc=}%U#5V}VdVW26mB+w)M>T`2SDnAxqvPWq8u?p$IO!a{S&DBE-rvFZ z9o6L9tIqlI+L|=YNza@UdB8KT8f#D4$lXVyF|XvEQel%<$K!&{DVW!Y97u3Xc<Dtq zaBytW9HyuTx#Lx}4Kn3I;&kQJUV_krPc0_JgOLVh?h`y2*!B?=DD@Y9Np}~wsJ+0u hM;C}c#DiW=e`oX;pQ93L;mP<6Or0M6ql}MRe*!`~r4s-E literal 0 HcmV?d00001 diff --git a/classifiers/tf_graph_creation_helper.py b/classifiers/tf_graph_creation_helper.py index aae6a23..db49175 100644 --- a/classifiers/tf_graph_creation_helper.py +++ b/classifiers/tf_graph_creation_helper.py @@ -8,13 +8,29 @@ graph_config = { 'num_attributes': 4, 'obj_feat_dim': 392, 'atr_feat_dim': 392, + 'region_feat_dim': 392, + 'word_vec_dim': 50, + 'ans_fc1_dim': 300, } +def get_variable(var_scope): + var_list = tf.get_collection(tf.GraphKeys.VARIABLES, scope=var_scope) + assert len(var_list)==1, 'Multiple variables exist by that name' + return var_list[0] + + +def get_list_of_variables(var_scope_list): + var_dict = dict() + for var_scope in var_scope_list: + var_dict[var_scope] = get_variable(var_scope) + return var_dict + + def weight_variable(tensor_shape, fan_in=None, var_name='W'): if fan_in==None: fan_in = reduce(lambda x, y: x*y, tensor_shape[0:-1]) + stddev = math.sqrt(2.0/fan_in) - print(stddev) initial = tf.truncated_normal(shape=tensor_shape, mean=0.0, stddev=stddev) return tf.Variable(initial_value=initial, name=var_name) @@ -94,7 +110,7 @@ def obj_comp_graph(x, keep_prob): logits = tf.add(tf.matmul(obj_feat, W_fc1), b_fc1, name='logits') y_pred = tf.nn.softmax(logits, name='softmax') -# tf.add_to_collection('obj_feat', h_pool2_drop_flat) + return y_pred @@ -140,73 +156,110 @@ def atr_comp_graph(x, keep_prob, obj_feat): logits = 0.5*logits_atr + 0.5*logits_obj 
+ b_fc1 y_pred = tf.nn.softmax(logits, name='softmax') -# tf.add_to_collection('atr_feat', h_pool2_drop_flat) + return y_pred -def ans_comp_graph(image_regions, questions, keep_prob, \ - obj_feat, atr_feat, vocab, inv_vocab, ans_vocab_size): +def ans_comp_graph(image_regions, questions, keep_prob, obj_feat, atr_feat, + vocab, inv_vocab, ans_vocab_size, mode): + vocab_size = len(vocab) with tf.name_scope('ans') as ans_graph: + with tf.name_scope('word_embed') as word_embed: - initial = tf.truncated_normal(shape=[len(vocab),50], - stddev=math.sqrt(3.0/(31.0+300.0))) - word_vecs = tf.Variable(initial, name='word_vecs') - with tf.name_scope('q_embed') as q_embed: - q_feat = tf.matmul(questions, word_vecs) - # q_feat = tf.truediv(q_feat, tf.cast(len(vocab),tf.float32)) - # q_feat = tf.truediv(q_feat, tf.reduce_sum(questions,1,keep_dims=True)) + word_vecs = weight_variable([vocab_size, + graph_config['word_vec_dim']], + var_name='word_vecs') + q_feat = tf.matmul(questions, word_vecs, name='q_feat') with tf.name_scope('conv1') as conv1: + W_conv1 = weight_variable([5,5,3,4]) b_conv1 = bias_variable([4]) - h_conv1 = tf.nn.relu(conv2d(image_regions, W_conv1) + b_conv1, name='h') + a_conv1 = tf.add(conv2d(image_regions, W_conv1), b_conv1, name='a') + h_conv1 = tf.nn.relu(a_conv1, name='h') h_pool1 = max_pool_2x2(h_conv1) h_conv1_drop = tf.nn.dropout(h_pool1, keep_prob, name='h_pool_drop') with tf.name_scope('conv2') as conv2: + W_conv2 = weight_variable([3,3,4,8]) b_conv2 = bias_variable([8]) - h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2, name='h') + a_conv2 = tf.add(conv2d(h_pool1, W_conv2), b_conv2, name='a') + h_conv2 = tf.nn.relu(a_conv2, name='h') h_pool2 = max_pool_2x2(h_conv2) h_pool2_drop = tf.nn.dropout(h_pool2, keep_prob, name='h_pool_drop') - h_pool2_drop_flat = tf.reshape(h_pool2_drop, [-1, 392], name='h_pool_drop_flat') + h_pool2_drop_shape = h_pool2_drop.get_shape() + region_feat_dim = reduce(lambda f, g: f*g, + [dim.value for dim in h_pool2_drop_shape[1:]]) + region_feat = tf.reshape(h_pool2_drop, [-1, region_feat_dim], + name='region_feat') + + print('Region feature dimension: ' + str(region_feat_dim)) #392 with tf.name_scope('fc1') as fc1: - fc1_dim = 300 - W_region_fc1 = weight_variable([392, fc1_dim], var_name='W_region') - W_obj_fc1 = weight_variable([392, fc1_dim], var_name='W_obj', - std=math.sqrt(3.0/(2.0*392.0+50.0+ans_vocab_size))) - W_atr_fc1 = weight_variable([392, fc1_dim], var_name='W_atr', - std=math.sqrt(3.0/(2.0*392.0+50.0+ans_vocab_size))) - W_q_fc1 = weight_variable([50, fc1_dim], var_name='W_q', - std=math.sqrt(3.0/(50.0+ans_vocab_size))) + + fc1_dim = graph_config['ans_fc1_dim'] + W_region_fc1 = weight_variable([graph_config['region_feat_dim'], + fc1_dim], var_name='W_region') + W_obj_fc1 = weight_variable([graph_config['obj_feat_dim'], + fc1_dim], var_name='W_obj') + W_atr_fc1 = weight_variable([graph_config['atr_feat_dim'], + fc1_dim], var_name='W_atr') + W_q_fc1 = weight_variable([graph_config['word_vec_dim'], + fc1_dim], var_name='W_q') b_fc1 = bias_variable([fc1_dim]) - h_tmp = tf.matmul(q_feat, W_q_fc1) + b_fc1 + \ - tf.matmul(obj_feat, W_obj_fc1) + \ - tf.matmul(atr_feat, W_atr_fc1) - #tf.matmul(h_pool2_drop_flat, W_region_fc1) + \ + a_fc1_region = tf.matmul(region_feat, W_region_fc1, + name='a_fc1_region') + a_fc1_obj = tf.matmul(obj_feat, W_obj_fc1, name='a_fc1_obj') + a_fc1_atr = tf.matmul(atr_feat, W_atr_fc1, name='a_fc1_atr') + a_fc1_q = tf.matmul(q_feat, W_q_fc1, name='a_fc1_q') + + coeff_reg = 0.0 + coeff_obj = 0.0 + coeff_atr = 0.0 + 
coeff_q = 0.0 + + if mode=='q': + coeff_q = 1.0 + + elif mode=='q_reg': + coeff_q = 1/2.0 + coeff_region = 1/2.0 + + elif mode=='q_obj_atr': + coeff_q = 1/3.0 + coeff_obj = 1/3.0 + coeff_atr = 1/3.0 + + elif mode=='q_obj_atr_reg': + coeff_q = 1/4.0 + coeff_obj = 1/4.0 + coeff_atr = 1/4.0 + coeff_reg = 1/4.0 + + a_fc1 = coeff_reg * a_fc1_region + \ + coeff_obj * a_fc1_obj + \ + coeff_atr * a_fc1_atr + \ + coeff_q * a_fc1_q - h_fc1 = tf.nn.relu(h_tmp, name='h') + h_fc1 = tf.nn.relu(a_fc1, name='h') h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob, name='h_drop') with tf.name_scope('fc2') as fc2: - W_fc2 = weight_variable([fc1_dim, ans_vocab_size], - std=math.sqrt(3.0/(fc1_dim))) - + W_fc2 = weight_variable([fc1_dim, ans_vocab_size]) b_fc2 = bias_variable([ans_vocab_size]) - logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2 + logits = tf.add(tf.matmul(h_fc1_drop, W_fc2), b_fc2, name='logits') - y_pred = tf.nn.softmax(logits) + y_pred = tf.nn.softmax(logits, name='softmax') - tf.add_to_collection('region_feat', h_pool2_drop_flat) - tf.add_to_collection('q_feat', q_feat) + return y_pred - return y_pred, logits -def aggregate_y_pred(y_pred, region_score, batch_size, num_proposals, ans_vocab_size): +def aggregate_y_pred(y_pred, region_score, batch_size, num_proposals, + ans_vocab_size): y_pred_list = tf.split(0, batch_size, y_pred) region_score_list = tf.split(1, batch_size, region_score) y_avg_list = [] @@ -214,17 +267,23 @@ def aggregate_y_pred(y_pred, region_score, batch_size, num_proposals, ans_vocab_ y_avg_list.append(tf.matmul(region_score_list[i],y_pred_list[i])) y_avg = tf.concat(0, y_avg_list) return y_avg + def evaluation(y, y_pred): - correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_pred, 1), name='correct_prediction') - accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy') + correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_pred, 1), + name='correct_prediction') + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), + name='accuracy') return accuracy def loss(y, y_pred): - cross_entropy = -tf.reduce_sum(y * tf.log(y_pred), name='cross_entropy') + y_pred_clipped = tf.clip_by_value(y_pred, 1e-10, 1) + cross_entropy = -tf.reduce_sum(y * tf.log(y_pred_clipped), + name='cross_entropy') batch_size = tf.shape(y) - return tf.truediv(cross_entropy, tf.cast(batch_size[0],tf.float32)) + print 'Batch Size:' + str(tf.cast(batch_size[0],tf.float32)) + return tf.truediv(cross_entropy, tf.cast(20,tf.float32))#batch_size[0],tf.float32)) if __name__ == '__main__': diff --git a/classifiers/tf_graph_creation_helper.pyc b/classifiers/tf_graph_creation_helper.pyc index 9b27a5400c59cc016fa82b56436b143641f67169..02d92a889c316357acd64391630440f3d31982be 100644 GIT binary patch literal 11829 zcmdT~OLH7Ya?Tz+2LuQZ4+0=a5PT0`fCx&IRx7PYkutSj3A<|8gbYjDZcNt<8XQay zx*H^+n6Wo3+PmQhFT=MDNBHE&!S>1Fo9%;r_SNA~_+bA8N9<`&_V;CVHwIitVQmNp zljxf2tm>-ltjsU7vbyKzBST+&_KQbV$^IS3?_FHUTP2YIe~vUn)}6dwlJ!zv?~(Q1 zyxu2U{jxqF;eZ6_&?9&F(~F-z3Ht5Nfc+W7Pg#N?{2Y^DSnjM3N-!eNF%8JawKBws zgHZ{_<T<7smf(a0<MP~*^$`gsB$(9taS5g*IH`4Xo|a%n>tkA<)jA$LrDKjuFsBtf zep)M&5}eTr2AS8&NeRwsWm?RP1Pk)GB=H-i;GC?_iW!itQ}VgUCtzt#)=x`tUV;lU zy?#d9=Ea<m&%xqZy<HG<mbdJU=g*7jm8}bWe*L1Di?Vfz^%wstjyWObvY2!1&B`1r z!9@wsqvU)7IbV_WimYE1GcM+en5&@6`ZXQ#x&)UcxFSIXPdV#1#7ty;%nb>y%IxQO z{$}3wns&XeT^GenYS#)!%zB$eeF)4LaGW#q`Jk*X=@_r*3)~WOOSYDka+?$|?5ttU zKKi(+O0!zI0WAvu2me;WUR2X#D4M(5UbL}gs%er^(|GTvX&i3srX~r0;-KF!=;-6s zjGuZzxUB>~jrR$o=n<QG7&Qx8xD1;1d>X}p_rz5BXn2DIa$>)kc(v*+604vIlv~Kv 
zeO$>s6sgFTqh(39dZ>MEzLe7eu~*s-Ul1A%bx6J?DRjDrl;1};(0gTewbF-CR?vb0 zt5M5@oF=)1;=|oW8nzl{;fWuI{zk(j3+7pvqzfB+3+bkxE;RjZ6QZM3`cb6hEv>11 z>L&H7QKJEY;ce40S{}xJYx8?%FVVY4_uTK^d-sF)R?{A|HB3_<yO$&;<`eJ5aTHf{ zKt2lwJY_B&%<dZCqk_6<kvv3gX*1e3OR3-7_V<?VM}b*-Xqt~cx|jb{8-9|6wa~=L zQd;xs<j1SVSS>EdrfIZHyx7{)#UTIrTaRJvgfA%0oHOs}GLY$F85VJaWq>k4=+EZ4 z01rMrDV>s}5Oaz0hOkk2YiCT>P+hH1#w+&yI7|>$Q<w8pIIxPMcx8{TTD-URqy`?* z%THlwzC0)OqFNz@<RlsrN?)9^GvZ7-csG%?DjBBk4B)ztD~V80w?rYRLK#3QpM!|^ zRD?YT41lCyZ%`TPtm^3$#MJ4b(r!UZ`?Pz%1c2p1X+vqHr_-ygAeN;K03TA}N@=x1 zOIfK>UGi&{tdL#I3T`HIk%J;u?qvU&zQo=U$&%B*)$q4B0{=T_QBQD*+Ka+VowxXR z*V^8u-!gc4k!|BQ^EOGTcu5)r=1Dq&YS;|Z&~IQ}8oRbTjj%%L>%%Ki)95p?33jX5 zfPT8Ejmb`&y5uUQuBGvAv+AcN@S0J)?Kf1%9%ZJksGz2hnr0Hg{FFa(Q}df%*eo=X z^Sg5#k4$9z9g#!{pifO$-`uy(Gaw~HD8w0dPM0QNp}MFWst&n!s%B<(3b=$qT2&-g zN{N;V7*)RyGo|tYH>p^BUyF?@@M2lqSv0t(S&#uWTPs<Lw;DAQ;85ybDB3Qv>tz;K zNW*mm$VCO&2>qn6T27cyr^K0YR0Y7=6SY}NhR~{#xrK^#BoKnNm`p^FhW^!IZ)&L& z4j`@6lM9Edp1>0Z@Po7|RfCov1YxtTQ?B0o@V%F%L^EwRRZ+>&a$qG$kZ}n?DypNA zl1q<!9almo#8D<Ty;c-83Iw4(sbKD+(k+-maVU(EP%8L?gm7C>G;joPQHh-1)e2R@ zD$(O4{LDJokdwRANOy_FD=ZGGlB=u2hqwJ_%AW$4t3u++9*um|=<^x9<!8&}FP9a$ zuBYk(sa9NjmD8V~V&#+>k~#pZIuH!xl!YW+(F*ZiwGAR<v>Ma!9=?CbF#S@e7j@ts zAX7L4Xke{JR-vg7GoSrqdWnbcvlUGxSq%hUmioP-rW(>-L+ZTyItmSM?yKu>))!h0 zziKw4MquKFS`;tT<7l@TEC3(VO&bEz!8IeVo+iZcmFO!UsZcT&4Ct4-Mpx6QH6lN~ zz3k4TO}na*Y$IOLmDT9zu`w;L6-OKHG<)GX0<;*%s9nqf(_6b~avRiyP2ya1#(}AQ zK-3ZaRotLH7R&wxZVp*53mV0;!>PZt>~Q2=6dIK}>TH$!eyJ5U2EDU7(9b1?rf}K+ zo?EoTR0fKWb%z;3UjSMNGEncAS|1b#BUr`xw@h)0KIqW<73!l9_ce9MeFI|pdYa>Y z3y-*OvQRyt)UDHwQhNWW>Th@i6O9$QN!Ta3-)3hG7PkGm@$8Nxao-}DCb6??62`oV zntHY&L3kY-ZLX&t27GGbBhowK*DvcG=X?h(8faSGn|J0*R{tif{t<mn;bIt|u=55e zR+FGZz>^mja$l|vo|K<zP+1@}G^x+RkuaPDX5&_$?O3$~QDaZ&3&K<=4m1q9g6(Xl zU+3zVE!88^2NjhD1i=uLi-0zm_d$pXBG}Mym5-NY4O;xs#~H&z%J9#nGsG~MZbW%Q zYC+LbI>(ef+*<aAb!5vPMn0zOA-_<tH<C{c_K<5R*gLN59ar{7^FCm2^dNg<(jJp7 zaD{Em32r`-;dl<KgJmKUF(PLld^#$facPgs*0B7rBs*IUB-EMEAxHQ$av<#q8XhMf zRhc1EF~$uI=n&8J!}kx!;DpNH2hy3M4A3V<GGtt4)pn(G^56nZ=Mn&RfZIuFpHv|r zdxH5e0fao$C1kAB327r6g4RjQ<aB1U*7hvC1LKSp-F{-%t}!@6ky|0bUR{M^zdB`& zD19wAbz5%Z1uuxBR`KL=NGmO32rBFUZ5LA0f$K+OZte`3!ew_KjL>xVq49g3u1IKA z-`D)c0@D|}u~|UI0g*po;>`sW60#63df6RGBdUi;g=p*-)l!5D75WV8^%y;IS75j7 zF5k;#QbmVBV5kw8jc$5Dj!k3Ny~B16N0COmsYN1eM0MZc((Ss}@;<S#yAk?xgr0WC zZnyJMvliAhq}%q>?cIj^1N5oz^=#Z!fFNI8QvmL-<-@KaWdQ|HHa5EIepekCgQAaB zf?GO$Mpun+fBoQy%UvVdx*gHh?UQ+(Vd!!;qSmdBQMh)xWhG?WF$w_d<Wkq%4_U0B zP>oz;+DK~~thwK3LqCc=1QrH4H~iD$8_k^@$EcKSHQH`j@sgp0t1HfUX}mP!j5_DA z85wrY9k`nVrR-dArkrwV3N0tm?^<cTq%j9wV1fBRdKvQ{mzXxtcme@uhWWtBLopv9 z-;Mbe(hJP*Q#=7&;Ll<{upJ{?%ttB$bc^}Gd4MqM0$|Vj095hUjrl`ea0GgR`M~x= zF&`ig_6p3mkZ#$-$ivE>#dgacu)SapIpu;q;5^s^1cN=`ukC{o0Gr+H0Ul4}m_L3H z^O1DUF@J){d;s$&bx2@7JOi+O@*w6TDQz*Ip!+45|1(8*?7J=C^JhYFA4rYVIKdil z51)Xk4zvOC(Kd5%8E2Iryxdfd_tWYPPUUnAWWB_U(&r?YmG&u_<j<T;PQxKgcTRUL zmh~M-ltZLQony)~x$igw-@$x`^&Lp?DuNv5Hr$8k=DyG3yDW$czk-u+zsH9^Kyfry zt1jY?F#KUIB2Qo^+??LaQ{IPp3B?`@Mf5Vdh0mY}kdW0dTsnkS>p+l0KX8kFh&TA} zzuvLPU3eAuBT)B!iA!I!??_yFT6H4sZ?N{8EDERLKH}{fi;r3S77KL@M3#NWU?Fs! 
zG&%+Mw^@9`;!_qsWZ|*+9Tr?o*JrW8f<D$ofG4hDQDecEaW`2Isof?EMQDcMYdld< zl(u&&X(06$d+o4@SrEA0$0&|w@$Cizrxg2;y^G0`2PnQCWDjARSa$lg90kPBARPw) z=gV;ebOP6XT*;rp=>w4vT4&(tXzI-d@O!<Chb_f|`M0x37|@EHGLi|54sEPVvjk$I zw`S~skNdzqk{h!fOOcwJ0U#z27Uy+(lD8ji12JLh1@Z#G@Su(C!7GDVjd0P2DC{sq zMu#%dlPC!Kh{A-zC}L#38{%2pXP5}eO((h`5Xa+_2ia4keu2#Z0%5R+Q$+kB1Ru*c z0|<rjEN<;)0N7ZthmpJEV9Oo=vS80fz?MBA<uL);kv*Ve!QQEZ?9CAtx9|uy1E(3k z;~4X~-3$N|$Ft4A8EK!9tx??!{EKb|=5<H_F%TAf&F78*o&Ya>me8KU>CAKf{mlTl zfNY5D^5PdxXF(TwT4p=vC=IUqjLf#r>0r1a_fF@0-gc_HjVF*!=R)3gy1R{MoKELr z-UdzDACnw8olBH5iB`&&^=+VAXQh2!7Z2y6c#kVN9a)15N?%aA9(GU{k!!f9^m9tT zn)iptyrlH=O2?Vz1?k`%6eW&4QQp*Yk&y+CQJLa+tp|sxokeN>Sc#XF_;C-`07s`& zXvk|z(hRgqMZ5fgcDY(~VGVQOG1hKM^L6cfO*{XicD`P8X3chH4NT!N)~@r=rhQ%7 zS7fq%L)sO(&YRM{%3pAO4ZodN^lhzWyrOSwRqs|#hc(>F)$N7~V)>vo+{9wYiw~r8 zn`?#jz~(yJ627V{gj>5puj!aDQ>YN+zAWumF`%?x(|H$l@H<_DgDdP{?`Z3iE;+>T zx~$q=DbRbh!c1)D+L0lluhNrQv%O&gmwQ9&*a&O+k!@l(iHxpy^knzYh~WQGF!0ZK zlb3@5<}$kuuXW4l>S6Ym2LqId3%F-AN{45kw=qJ*Hg$_Fc3^GI+tIC$uSedtvu*?} zfl&m-Ztoa63>0@eS=@k;7<}Zn|NV}}9Xp&k?{QZL*-`I{K{E2-T}_!4<gO>i?&`*A z({9DcXsx<z{jvrT6QH;}{nzbpSD>;VY;gPgtA;E(2R*2Jm&H96+{W5yg)xP`qb30L z)xzczn+|lDP_yYk>hXSx(25aRL%Rch%5KkC?6LT!V;?Gj`xy(3eYj}ZCk?vVxmxHp zm0IT_I<UzAD1=fqB<ir~cUk-%iyxuT_ozmuR^xJY-Q}uc+j3R0cUH%Wr;y=L`KeZE zZBv6XN=$2ZPydIUj^@ve^|IEuvCj11sN*v%)qZ^T=inmwDGEV)l$&ZBbIl*vP=EFI zddL|<a=-f;EsY|2W7IZ-@$^4NFXKX7r{(@HYQ59WF>K~B3&t5OT`SF$s+10;MytP% zE8${s?};s^9?Z0H+S(&oq7_5{HU5C)AmWS%vBrQr1Eeb*F5tnZSC~-+`r|7B4g|jn z|2zoC4Y0=#pdQ_vowsF;%}CZ_9`1~}DM!G#s>g~eA@~r@-0Oj^%~Fxuq}2#hHST9I z4&Cj^x!-D;W}uq{Z9>8a$K(~hna1JnyQryS*znWprcJ{$-Oi7KaA?zt62FnC?P;#i zCS2@i91Majt}!_l8%yQTd!GNKZf^--p=y{Ck0|aR!z_Ni4sP%bfM=QP!_>)fG)QR* zshECe!nuU?tuntvE2hkJWr=KdWm#gqkcYxL=I7{4GU%;68fDC6n~G($x0-KZAA3Fs zUyoBg`$3XDXyK+a<sJ~iit#a#=C1Iu42_m9#jjR(W52rh(z=?RU3||#J@n%`l9Rd) z_KOIQpI0EH#E^cH>Jf`Qu;WRaUEjZC?XOTAk>t7DhTjwDmq%liIZ5JNfd$UtQh8I> zg}i@)8!K-sS|+++G=;ovE99pJF^OH073r}oH3q|h?IGqz_>2^oJ?eld;kI99IkH*} zdLSz2Ev(M26?Ppq2WC6qXw{x)*Z<+4{{4&pZ2!mGT?<7csK!x}c=(<RhOn0~Xn2=a zvapK7syD46PJ*asv(^yxsYclHHukd5S}2WdwUXU!4GoYlu-{zhF8;64L6=S=036hs zAbg^FqX(wCjJmPSY_E#pTL|C&me2her^jWBZ#4Z$wQAt~XSk+}2n1F%3j`f?&O55G zp119XP0vdf(fFlbtiAWF<+&!o_jk!sqwdv0e0#T)EPT}?*|%;9Ga8y5PrSfS{luh8 zY0N48W;0scN}{F)1bQZEBMAEAgmD#|<`T2p%eqXOtx!7Lr=?e3MK1*=Kpno6BP=vq z?gFV?GyAC!4!ED%mAwd>JxcWUKfYV!X(d_n!YKRdPA&WpGBtzbqN7dU5V>3`k)Dlq z<%E_=W3JhKo5ceb4HnF3xIbg@7cBmU#owXGmP8LLJjDx0I|yT>9I44T2{AmE1gUCI zBzR*3ErPU{GW6RQuPzh^(@?VI$BFTFe3iT^D(b~Z`h_8naorzd3|9@D=B}SfYO+s7 zijR2$Q}dw-Hp0Bi&i7bo2*|j@<=W{|>Lga>mv#BN2)DJqYVZjpu#<|Kpsu!s@4(U- z5VCK#V5pwo;3sC;SD*XFqP1_-Vg@nqTln~4w_(0RktXyk!#E5WMpV(`40Qb;>M8e^ z2S&@oeZ#%@KRi5IzEWN+Z<LRf2g~JQryubTo<JNkg?MG8bPZNF<($Md>>O9x{{pvf BdmsP+ delta 4033 zcmaJ^?Qfe`6+icJ?8LDXza)0-IBnh&Crz3*t6J8UG^{|ov2M%FXg0U2as50=oH}u8 zKk3%VmM5i~%6KW;ox~(wqzOKOgcud@VMyZ-;LE;%KuCa)5Q2}uXE?uOJ9V}R+t2Sg z_ug~wJ?GqW@9+BO*Z=lh%JZ+D?yvpn-&gb0{PgmBnV(ARt<;%%2iX?c9y*Oz2fw^z z`x>uKI-ThvTanG^r+XG<y2<8WfbJ>E1jr7O-6Jv!y<~?(?jbu&c0^>ZWH~Bwi0nSP zN0i~3ei0&M*JC0@$&QQANA`dS{bVOZh>;T~J4v@JDn~4PkTL`0bkbIW_KB{uO-afO zlAR)Zhz2q#X&WLZMf+^y5h)%e=Lm|>9;VC)IUTe$ivF4U7&&9KHIBlA-;i=b<V=t= z40Q-e2=)=OkJ7%Su6rmmNtr3iOp_BPXOf(0uBS{|hMXaLnCua<d2FRJv*bjYI?gQF zW0c(I{>NIX<4|3<Cq#FSoT%tdVaTSsGbg>+8i#FeS+n0onRyxHxVW%D&H`<nkeZXI zVT;uoFWbrsJA~z1O;4KZ){6<ni@uikFnamv<mVzk71w;nyX38U=+b5Ld++a5$h_~n z-|1T1=B}Hi&K1>dzSVg*+5_bv1F4`IQhmy6l3g3;y0}6I0BB2NcpYJI(U33gRNUN- ztNRYx`k@d93>X#2n3zAJ;%3_a)j(T+^MOAVfR(yK)XiV~8Sfa2DYM-D#<`=Y90rU4 z+B%{*${@3E=U&O~mP*Ame9<yooafw9!=@PcFbo3=Nq`tVZR)-AW;6KsWicNv?G|%+ 
zXR}naowAE~trtr<_pwEdP^4j*3)n8%j*Ra5xpd2McC)+X(oO9X|1S2#)%e3qPgunc zPguZuXv}(CrOi)z=fjv+&j4ltJ?5XiClkk5bTMIJJGbFv%g#oDgR1AusnC0B;o)yX z?^r5mZbd)z!opRv(6@P39O&laPHx-DR`R7?M<=+bhZtOVo3m|=cs7CsFUn4FPfwYT z`i}Z=Lg+#BpT1M$%Pgh`QI_z{u_RuWZf<2a%DLT5ecC+TzoO2X@AvP`e-?^M(1G?^ zc0GSmq??fR1vK~=?s9w_+bHjv`F1R-mLC2n_E&4>G`AGcPnM$~3y+o6X8@l8JPtT( zevvriI|uQ+<OTeGmHi$th2*g?rq{CI65qJ_Ve&$KnJd##v`CCU2k9Un`Uw;dqM@-X z^c3LpfG+?VA%XNoz$L(Cz>|PwhO{K*L6o$_;wiMArWYB%1o3IW6~HrqX93Rv&H`2d z8n6m@9`I$rRlpkH1%RaM4M<DogTawl24W6y6M#7DYm$6VrJ~9)FQu;3k$cFh7(Z~Z z@;;B^5^eGErf_6D_^zkc)vB?i<_gvPQX>q(c~kYG-tM3cuJI!Kw}dY2DT;B!K3TPH z>ENR(^MoI4!6#uYK=n17_&{n-7?9zUR11oBCsh?yxzeIq4;|7Bisq|S>&2F4w!u%R zW}^d5wGi@=nP)?ssD)*~ZrTcv8>A{xHH2ZS^)L;<p0L;x!H%j%WPoU^M=w>QRAt9l zjhif0>pN^qRBU;dYW-phTOo!pt2J$ji7i}g*b;Az%(m3o7#=+)#>Crg84z0r#Fj*> z2iuZ3Y)ev_LR3wHJ?)q*0b3duO9u~IDx*D^RW)YC1l5LkehyxYszcI^hZvMN91%y_ z>xD)it!tX9oire8tE}45PpC83@!$t(FndUSSkyaY);h<PryV`4KYU1k1mS~;jqr_% z$DA9yR3j3%Q5rx3aTTW`2>GDQlQ3m-xi;1cVSuV*vX-*s_1bu=fnz(~itU(0Z{l$D z_%guTVTEdwaDWZs7$D)>JmtzpP*iP7tY(dd)zdAjC#X6lR!_54>1W%f;Oq=jH7zSX zBkFt+Vs%~<ixfQM@+AC5xjF;gb&FOT8-w@JYWkSDHZml8y0FNbS>vqV;D<k5%T3`z zQ}}gxSr}5z?UzesJA22;SA@@QXPxbvj?E|3Cy(B?0~f8rUH|1z;Xa?MZkX%po5HSH zyIk7Mt`~Ez9Q}nkZWn?V4qj_Amd<hyp`MCsH!Y;5g|EHRviZqn{gsy9TK2ZE#imRk zUDxxAE_k^?g7xgWlXFGqwup^kf9ke|z26kVZ+=S6$LRp<Ad?iG9UbC=hL3V_X&my8 z_+k_d7l{w^25UUx3L%GOZg#;lOtM+Gl+BllMJMkvZKrVp>l1*J0MKcJC~lbU@uB(x zmt~o9JC!EA!Y4fd*aQ>+w*gqQHd6WjeDXY$52!?pYb@Sk`|6orj*q^Shbku-Nrvzu z2zP^GgvvLw?)Gj|0=*;iHlZQ|=~7D%<iB`8{)_FT^pA<?4a0gB&>C4$6;&w}S1Fbf z9sH<}n&d0yD9TMar6#!K<8n+59y}q>xa#vHL^Gt~R>&H%qE=Xi&D)bdRuQu_HIV{8 zON#JSG@o<RuvcFLESlQXq*^dPn0grah$|}Kcv2=e{RUsC{ysg=BgE5-gIF`U>}elM zjlSQ2^i84IU!{BNf8b&Tt)xz<36;n7qk|vZ283)8`UdwgB|%_HlJlaU$(R3mF-Zn$ zmbCEECRg~HM~Zy1HR`=M^=uC%(cOomfq$!FnCBTCG;htsLdXtEcQVL5$NX_7aRM7x zmeMWnId<WWejCl7=F~fAcpXl~VSKSvsr;GWO~=OhkHODo-`v7|RHp&B;m8Wh1+(Jk zB3}y^IM8<%#S)D>iqz-aa-%+t<4QX1>}~Jfm46vKyRzRTm0V}1QY!O3`(+(5?;X4P z#AVi(`*z!XO|ELP>E_XT3UD3Z0A2%p9e^a$@0yEq_v;)q(t8Xt_a(Wxo<|Y+kgH^E zXFa!9bafeeAKReF78l8};ka#^x#2&8-MG&#in|i(M#W<yuR4`V!JMD}-f?WvASZTE kg)EQilyA2u<PUm-9sCUjL;h+1asQIv@AvtACUSh@fAI>qU;qFB diff --git a/classifiers/train_classifiers.py b/classifiers/train_classifiers.py index 2fa365e..3bf5366 100644 --- a/classifiers/train_classifiers.py +++ b/classifiers/train_classifiers.py @@ -9,14 +9,18 @@ import object_classifiers.train_obj_classifier as obj_trainer import object_classifiers.eval_obj_classifier as obj_evaluator import attribute_classifiers.train_atr_classifier as atr_trainer import attribute_classifiers.eval_atr_classifier as atr_evaluator +import answer_classifier.train_ans_classifier as ans_trainer workflow = { 'train_obj': False, 'eval_obj': False, 'train_atr': False, - 'eval_atr': True, + 'eval_atr': False, + 'train_ans': True, } +ans_mode = ['q'] + obj_classifier_train_params = { 'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier', 'adam_lr': 0.0001, @@ -55,6 +59,24 @@ atr_classifier_eval_params = { 'html_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/html_dir', } +ans_classifier_train_params = { + 'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json', + 'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json', + 'regions_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/regions_anno.json', + 'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images', + 'image_regions_dir': '/mnt/ramdisk/image_regions', + 'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier', + 'obj_atr_model': 
'/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/obj_atr_classifier-1', + 'adam_high_lr' : 0.0001, + 'adam_low_lr' : 0.0000,#1, + 'mode' : 'q_reg', + 'crop_n_save_regions': False, + 'max_epoch': 10, + 'batch_size': 20, + 'fine_tune': False, + 'start_model': 9, +} + if __name__=='__main__': if workflow['train_obj']: obj_trainer.train(obj_classifier_train_params) @@ -67,3 +89,6 @@ if __name__=='__main__': if workflow['eval_atr']: atr_evaluator.eval(atr_classifier_eval_params) + + if workflow['train_ans']: + ans_trainer.train(ans_classifier_train_params) diff --git a/shapes_dataset/evaluate_shapes_test.py b/shapes_dataset/evaluate_shapes_test.py index 1a53663..f75b659 100644 --- a/shapes_dataset/evaluate_shapes_test.py +++ b/shapes_dataset/evaluate_shapes_test.py @@ -7,7 +7,7 @@ if __name__== "__main__": res_data = json.load(f_res); anno_data = json.load(f_anno); - + print(len(anno_data)) assert(len(res_data) == len(anno_data)) res_dict = dict() # convert to map with qid as key -- GitLab
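
Note on the last-batch handling in get_pred() (eval_ans_classifier.py): the evaluation graph is built with fixed batch dimensions (aggregate_y_pred splits y_pred into exactly batch_size pieces), so when the final iteration has fewer than batch_size examples the code zero-pads the region images, questions, labels and region scores before feeding them, and only the first batch_size_tmp predictions are appended to pred_list. Below is a minimal NumPy restatement of that padding step; the helper name and exact shapes are illustrative, not taken from the patch.

import numpy as np

def pad_last_batch(region_images, questions, ans_labels, region_score,
                   batch_size, num_proposals):
    # Number of examples actually loaded for this (possibly partial) batch.
    real = ans_labels.shape[0]
    missing = batch_size - real
    if missing > 0:
        # Per-region tensors need missing*num_proposals extra zero rows,
        # per-example tensors need missing extra zero rows.
        pad_regions = missing * num_proposals
        region_images = np.concatenate(
            [region_images,
             np.zeros((pad_regions,) + region_images.shape[1:])], axis=0)
        questions = np.concatenate(
            [questions, np.zeros((pad_regions, questions.shape[1]))], axis=0)
        ans_labels = np.concatenate(
            [ans_labels, np.zeros((missing, ans_labels.shape[1]))], axis=0)
        region_score = np.concatenate(
            [region_score, np.zeros((1, pad_regions))], axis=1)
    # Caller keeps only the first `real` predictions after running the graph,
    # so the zero padding never influences the reported answers.
    return region_images, questions, ans_labels, region_score, real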
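
Note on the new mode argument of graph_creator.ans_comp_graph: it selects which fc1 input branches (question, object, attribute, raw region) contribute to a_fc1, giving each selected branch an equal coefficient and leaving the others at 0. As committed, the 'q_reg' branch assigns coeff_region while the weighted sum reads coeff_reg, so the region features still receive weight 0 in that mode. A plain-Python sketch of the intended equal-share weighting; the function name is illustrative, not part of the patch.

# Mode-dependent weights for the fc1 branches in ans_comp_graph.
# Every branch named in the mode gets an equal share; the rest get 0.
def fc1_branch_coefficients(mode):
    branches = {
        'q': ['q'],
        'q_reg': ['q', 'reg'],
        'q_obj_atr': ['q', 'obj', 'atr'],
        'q_obj_atr_reg': ['q', 'obj', 'atr', 'reg'],
    }[mode]
    return {b: (1.0 / len(branches) if b in branches else 0.0)
            for b in ('q', 'obj', 'atr', 'reg')}

# fc1_branch_coefficients('q_obj_atr') -> {'q': 1/3., 'obj': 1/3., 'atr': 1/3., 'reg': 0.0}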
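
Note on graph_creator.aggregate_y_pred, used by both the trainer and the new evaluator: it turns the per-region answer softmaxes into one answer distribution per image by weighting each image's num_proposals rows with that image's region scores (the tf.split / tf.matmul / tf.concat sequence in the diff). A NumPy sketch of the same computation under the shapes the patch feeds it; the function name and the example data are illustrative only.

import numpy as np

def aggregate_y_pred_np(y_pred, region_score, batch_size, num_proposals):
    # y_pred:       [batch_size * num_proposals, ans_vocab_size] region softmaxes
    # region_score: [1, batch_size * num_proposals] proposal weights, laid out image by image
    # returns:      [batch_size, ans_vocab_size] per-image answer distributions
    ans_vocab_size = y_pred.shape[1]
    y_avg = np.zeros((batch_size, ans_vocab_size))
    for i in range(batch_size):
        rows = slice(i * num_proposals, (i + 1) * num_proposals)
        scores = region_score[0, rows]        # weights for this image's proposals
        y_avg[i] = scores.dot(y_pred[rows])   # (num_proposals,) @ (num_proposals, vocab)
    return y_avg

# Example: 2 images, 3 proposals each, 4 candidate answers.
y_pred = np.random.rand(6, 4)
y_pred /= y_pred.sum(axis=1, keepdims=True)
region_score = np.array([[1.0, 0.0, 0.0, 0.5, 0.5, 0.0]])
print(aggregate_y_pred_np(y_pred, region_score, batch_size=2, num_proposals=3))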