From b9351f74614455ef144226426a3d973233672f39 Mon Sep 17 00:00:00 2001 From: tgupta6 <tgupta6@illinois.edu> Date: Wed, 16 Mar 2016 18:36:02 -0500 Subject: [PATCH] reset the graph at the end of obj classifier training --- .../train_atr_classifier.py | 80 +++++++++++++ classifiers/object_classifiers/__init__.py | 1 + classifiers/object_classifiers/__init__.pyc | Bin 161 -> 161 bytes .../object_classifiers/eval_obj_classifier.py | 56 +++++++++ .../eval_obj_classifier.pyc | Bin 2677 -> 2671 bytes .../train_obj_classifier.py | 61 ++++++++++ .../train_obj_classifier.pyc | Bin 3095 -> 3113 bytes classifiers/plot_helper.py | 33 ++++++ classifiers/plot_helper.pyc | Bin 1490 -> 1490 bytes classifiers/tf_graph_creation_helper.py | 109 ++++++++++++++++++ classifiers/tf_graph_creation_helper.pyc | Bin 5702 -> 5656 bytes classifiers/train_classifiers.py | 40 +++++++ 12 files changed, 380 insertions(+) create mode 100644 classifiers/attribute_classifiers/train_atr_classifier.py create mode 100644 classifiers/object_classifiers/__init__.py create mode 100644 classifiers/object_classifiers/eval_obj_classifier.py create mode 100644 classifiers/object_classifiers/train_obj_classifier.py create mode 100644 classifiers/plot_helper.py create mode 100644 classifiers/tf_graph_creation_helper.py create mode 100644 classifiers/train_classifiers.py diff --git a/classifiers/attribute_classifiers/train_atr_classifier.py b/classifiers/attribute_classifiers/train_atr_classifier.py new file mode 100644 index 0000000..dd490ad --- /dev/null +++ b/classifiers/attribute_classifiers/train_atr_classifier.py @@ -0,0 +1,80 @@ +import sys +import os +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +import numpy as np +import tensorflow as tf +import atr_data_io_helper as atr_data_loader + +def train(): + # Start session + sess = tf.InteractiveSession() + + x, y, keep_prob = placeholder_inputs() + y_pred = comp_graph_v_2(x, y, keep_prob) + + # Specify loss + cross_entropy = -tf.reduce_sum(y*tf.log(y_pred)) + + # Specify training method + train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy) + + # Evaluator + accuracy = evaluation(y, y_pred) + + # Merge summaries and write them to ~/Code/Tensorflow_Exp/logDir + merged = tf.merge_all_summaries() + + # Output dir + outdir = '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier_v_1/' + if not os.path.exists(outdir): + os.mkdir(outdir) + + # Training Data + img_width = 75 + img_height = 75 + train_json_filename = '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json' + image_dir = '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images' + mean_image = atr_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width) + np.save(os.path.join(outdir, 'mean_image.npy'), mean_image) + + # Val Data + val_batch = atr_data_loader.atr_mini_batch_loader(train_json_filename, image_dir, mean_image, 9501, 499, img_height, img_width) + feed_dict_val={x: val_batch[0], y: val_batch[1], keep_prob: 1.0} + + # Session Saver + saver = tf.train.Saver() + + # Start Training + sess.run(tf.initialize_all_variables()) + batch_size = 100 + max_epoch = 10 + max_iter = 95 + val_acc_array_iter = np.empty([max_iter*max_epoch]) + val_acc_array_epoch = np.zeros([max_epoch]) + train_acc_array_epoch = np.zeros([max_epoch]) + for epoch in range(max_epoch): + for i in range(max_iter): + train_batch = atr_data_loader.atr_mini_batch_loader(train_json_filename, image_dir, mean_image, 1+i*batch_size, batch_size, img_height, img_width) + feed_dict_train={x: train_batch[0], y: train_batch[1], keep_prob: 0.5} + + _, current_train_batch_acc = sess.run([train_step, accuracy], feed_dict=feed_dict_train) + + train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] + current_train_batch_acc + val_acc_array_iter[i+epoch*max_iter] = accuracy.eval(feed_dict_val) + print('Step: {} Val Accuracy: {}'.format(i+1+epoch*max_iter, val_acc_array_iter[i+epoch*max_iter])) + plot_accuracy(np.arange(0,i+1+epoch*max_iter)+1, val_acc_array_iter[0:i+1+epoch*max_iter], xlim=[1, max_epoch*max_iter], ylim=[0, 1.], savePath=os.path.join(outdir,'valAcc_vs_iter.pdf')) + + + train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] / max_iter + val_acc_array_epoch[epoch] = val_acc_array_iter[i+epoch*max_iter] + + plot_accuracies(xdata=np.arange(0,epoch+1)+1, ydata_train=train_acc_array_epoch[0:epoch+1], ydata_val=val_acc_array_epoch[0:epoch+1], xlim=[1, max_epoch], ylim=[0, 1.], savePath=os.path.join(outdir,'acc_vs_epoch.pdf')) + + save_path = saver.save(sess, os.path.join(outdir,'obj_classifier_{}.ckpt'.format(epoch))) + + + sess.close() + +if __name__=='__main__': + train() diff --git a/classifiers/object_classifiers/__init__.py b/classifiers/object_classifiers/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/classifiers/object_classifiers/__init__.py @@ -0,0 +1 @@ + diff --git a/classifiers/object_classifiers/__init__.pyc b/classifiers/object_classifiers/__init__.pyc index f3f2c48cd84fe79156e9feb76102ac9af6b5e2cf..e06a3d60937fbfd294f136300cbbcbe32f99b3f8 100644 GIT binary patch delta 15 XcmZ3;xR8;Z`7<xq#LF)yvd;nlDD4H< delta 15 XcmZ3;xR8;Z`7<vU<Ee)e*=GR&B|rr? diff --git a/classifiers/object_classifiers/eval_obj_classifier.py b/classifiers/object_classifiers/eval_obj_classifier.py new file mode 100644 index 0000000..6c78aa4 --- /dev/null +++ b/classifiers/object_classifiers/eval_obj_classifier.py @@ -0,0 +1,56 @@ +import sys +import os +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +import numpy as np +from scipy import misc +import tensorflow as tf +import object_classifiers.obj_data_io_helper as shape_data_loader +import tf_graph_creation_helper as graph_creator +import plot_helper as plotter + +def eval(eval_params): + sess = tf.InteractiveSession() + + x, y, keep_prob = graph_creator.placeholder_inputs() + y_pred = graph_creator.obj_comp_graph(x, keep_prob) + accuracy = graph_creator.evaluation(y, y_pred) + + saver = tf.train.Saver() + saver.restore(sess, eval_params['model_name'] + '-' + str(eval_params['global_step'])) + mean_image = np.load(os.path.join(eval_params['out_dir'], 'mean_image.npy')) + test_json_filename = eval_params['test_json'] + image_dir = eval_params['image_dir'] + html_dir = eval_params['html_dir'] + if not os.path.exists(html_dir): + os.mkdir(html_dir) + html_writer = shape_data_loader.html_obj_table_writer(os.path.join(html_dir, 'index.html')) + col_dict = {0: 'Grount Truth', + 1: 'Prediction', + 2: 'Image'} + html_writer.add_element(col_dict) + shape_dict = {0: 'blank', + 1: 'rectangle', + 2: 'triangle', + 3: 'circle'} + batch_size = 100 + correct = 0 + for i in range(1): + test_batch = shape_data_loader.obj_mini_batch_loader(test_json_filename, image_dir, mean_image, 10000 + i * batch_size, batch_size, 75, 75) + feed_dict_test = {x: test_batch[0], y: test_batch[1], keep_prob: 1.0} + result = sess.run([accuracy, y_pred], feed_dict=feed_dict_test) + correct = correct + result[0] * batch_size + print correct + for row in range(batch_size * 9): + gt_id = np.argmax(test_batch[1][row, :]) + pred_id = np.argmax(result[1][row, :]) + if not gt_id == pred_id: + img_filename = os.path.join(html_dir, '{}_{}.png'.format(i, row)) + misc.imsave(img_filename, test_batch[0][row, :, :, :]) + col_dict = {0: shape_dict[gt_id], + 1: shape_dict[pred_id], + 2: html_writer.image_tag('{}_{}.png'.format(i, row), 25, 25)} + html_writer.add_element(col_dict) + + html_writer.close_file() + print 'Test Accuracy: {}'.format(correct / 5000) diff --git a/classifiers/object_classifiers/eval_obj_classifier.pyc b/classifiers/object_classifiers/eval_obj_classifier.pyc index 5e07c1a44e9e386d228fd7b28919f8eba3a8079a..c4e6f97b11143b781aef4f2a8c5d37ac1afe9cfa 100644 GIT binary patch delta 98 zcmew=@?M0E`7<w9*2|4-i5$K`3_!rh!z9Wm#3;nX<z#K;APLW}~8azGY45c7g* l5Q_~6g&8#&l^LahI#?OSfh=W49!5n*Ge*A6yEt|;0sy7;3pfA( delta 104 zcmaDa@>PV5`7<w<+0Kn@i5$TZ3_!rh!z9Wm#4N<b&!onr#LUIW1;j#(0?cxZ%#7?n u$i>LZ1YrT$Ol*v7jKYkXjLMACKvk?v;y{)%BM+k@qZv>=_vS4eyBPtu*9z(Y diff --git a/classifiers/object_classifiers/train_obj_classifier.py b/classifiers/object_classifiers/train_obj_classifier.py new file mode 100644 index 0000000..292e1ca --- /dev/null +++ b/classifiers/object_classifiers/train_obj_classifier.py @@ -0,0 +1,61 @@ +import sys +import os +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +import numpy as np +import tensorflow as tf +import object_classifiers.obj_data_io_helper as shape_data_loader +import tf_graph_creation_helper as graph_creator +import plot_helper as plotter + +def train(train_params): + sess = tf.InteractiveSession() + + x, y, keep_prob = graph_creator.placeholder_inputs() + y_pred = graph_creator.obj_comp_graph(x, keep_prob) + cross_entropy = graph_creator.loss(y, y_pred) + train_step = tf.train.AdamOptimizer(train_params['adam_lr']).minimize(cross_entropy) + accuracy = graph_creator.evaluation(y, y_pred) + + outdir = train_params['out_dir'] + if not os.path.exists(outdir): + os.mkdir(outdir) + + img_width = 75 + img_height = 75 + train_json_filename = train_params['train_json'] + image_dir = train_params['image_dir'] + mean_image = shape_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width) + np.save(os.path.join(outdir, 'mean_image.npy'), mean_image) + val_batch = shape_data_loader.obj_mini_batch_loader(train_json_filename, image_dir, mean_image, 9501, 499, img_height, img_width) + feed_dict_val = {x: val_batch[0], y: val_batch[1], keep_prob: 1.0} + + saver = tf.train.Saver() + + sess.run(tf.initialize_all_variables()) + batch_size = 10 + max_epoch = 2 + max_iter = 1 + val_acc_array_iter = np.empty([max_iter * max_epoch]) + val_acc_array_epoch = np.zeros([max_epoch]) + train_acc_array_epoch = np.zeros([max_epoch]) + for epoch in range(max_epoch): + for i in range(max_iter): + train_batch = shape_data_loader.obj_mini_batch_loader(train_json_filename, image_dir, mean_image, 1 + i * batch_size, batch_size, img_height, img_width) + feed_dict_train = {x: train_batch[0], y: train_batch[1], keep_prob: 0.5} + _, current_train_batch_acc = sess.run([train_step, accuracy], feed_dict=feed_dict_train) + train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] + current_train_batch_acc + val_acc_array_iter[i + epoch * max_iter] = accuracy.eval(feed_dict_val) + plotter.plot_accuracy(np.arange(0, i + 1 + epoch * max_iter) + 1, val_acc_array_iter[0:i + 1 + epoch * max_iter], xlim=[1, max_epoch * max_iter], ylim=[0, 1.0], savePath=os.path.join(outdir, 'valAcc_vs_iter.pdf')) + print 'Step: {} Val Accuracy: {}'.format(i + 1 + epoch * max_iter, val_acc_array_iter[i + epoch * max_iter]) + + train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] / max_iter + val_acc_array_epoch[epoch] = val_acc_array_iter[i + epoch * max_iter] + plotter.plot_accuracies(xdata=np.arange(0, epoch + 1) + 1, ydata_train=train_acc_array_epoch[0:epoch + 1], ydata_val=val_acc_array_epoch[0:epoch + 1], xlim=[1, max_epoch], ylim=[0, 1.0], savePath=os.path.join(outdir, 'acc_vs_epoch.pdf')) + save_path = saver.save(sess, os.path.join(outdir, 'obj_classifier'), global_step=epoch) + + sess.close() + tf.reset_default_graph() + +if __name__ == '__main__': + train() diff --git a/classifiers/object_classifiers/train_obj_classifier.pyc b/classifiers/object_classifiers/train_obj_classifier.pyc index 2891e81874bd05649b01ab33a1514798af941eb8..d803fd2f3ff49fa7edbd7c3cee464cb231ef703c 100644 GIT binary patch delta 162 zcmbO(u~LGK`7<w<z{`njXBppYyz-ElwS<8oOL6lCmMBI>rOn)IhZ!YGgc%qZic*VH zOX5>f(-KQ_O5)Rt5(_dmpJUfy3JCy8GBENmi82ZSp**7y6Bm#cU}9rrW8?xtIiQ## yqcRgelQ2*O1jWEAg&BnzH5rwGazc#Kj0KFkOhQ1JVn%JC{>^b*+ZdS*c^Cm?OC6a2 delta 147 zcmZ1}F<pX<`7<w9%Fc;wXBi)Dyz-EF^M004Mn=WWqHKp5H-BK)VM<A000Jf+CQ(Kq zW+7&IAm(BeViaIzV`O9G0zx@Pd1ggMWoCY6VW0>IiUH+?7=?k5laUz+xxpkSlQN?s iP+Xc(l~IyWhEa@3mq`eywwOtqd2<2RHb!PI9!3BgHWbGI diff --git a/classifiers/plot_helper.py b/classifiers/plot_helper.py new file mode 100644 index 0000000..2697308 --- /dev/null +++ b/classifiers/plot_helper.py @@ -0,0 +1,33 @@ +#Embedded file name: /home/tanmay/Code/GenVQA/GenVQA/classifiers/plot_helper.py +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +import numpy as np +import tensorflow as tf + +def plot_accuracy(xdata, ydata, xlim = None, ylim = None, savePath = None): + fig, ax = plt.subplots(nrows=1, ncols=1) + ax.plot(xdata, ydata) + plt.xlabel('Iterations') + plt.ylabel('Accuracy') + if not xlim == None: + plt.xlim(xlim) + if not ylim == None: + plt.ylim(ylim) + if not savePath == None: + fig.savefig(savePath) + plt.close(fig) + + +def plot_accuracies(xdata, ydata_train, ydata_val, xlim = None, ylim = None, savePath = None): + fig, ax = plt.subplots(nrows=1, ncols=1) + ax.plot(xdata, ydata_train, xdata, ydata_val) + plt.xlabel('Epochs') + plt.ylabel('Accuracy') + plt.legend(['Train', 'Val'], loc='lower right') + if not xlim == None: + plt.xlim(xlim) + if not ylim == None: + plt.ylim(ylim) + if not savePath == None: + fig.savefig(savePath) + plt.close(fig) diff --git a/classifiers/plot_helper.pyc b/classifiers/plot_helper.pyc index 12089f68ec578fbb6a806c48bb6e9c393c472c7c..00bed5ede2ae422dae5734796e725a75a163c354 100644 GIT binary patch delta 93 zcmcb_eTkc$`7<w<)#aBP*{3k^GV(AA!13l4OoB{`q6`cS#S%b*fl-c8jFA@zCE-dy Z$|v7qc`D2VQpE)%_!;?uvOG*ed;kLL47LCO delta 93 zcmcb_eTkc$`7<w<fYsxT>{FO{nRpllV3>LH3MN4&MM(w*hGGdI!N4fTD8|SOgp!O% a%9tkKVtFde2vWrbB={NmfwDZzLP7w<6bt<T diff --git a/classifiers/tf_graph_creation_helper.py b/classifiers/tf_graph_creation_helper.py new file mode 100644 index 0000000..43d0c6e --- /dev/null +++ b/classifiers/tf_graph_creation_helper.py @@ -0,0 +1,109 @@ +import numpy as np +import tensorflow as tf + +def weight_variable(shape, var_name = 'W'): + initial = tf.truncated_normal(shape, stddev=0.1) + return tf.Variable(initial, name=var_name) + + +def bias_variable(shape, var_name = 'b'): + initial = tf.constant(0.1, shape=shape) + return tf.Variable(initial, name=var_name) + + +def conv2d(x, W, var_name = 'W'): + return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME', name=var_name) + + +def max_pool_2x2(x, var_name = 'h_pool'): + return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME', name=var_name) + + +def placeholder_inputs(mode = 'gt'): + x = tf.placeholder(tf.float32, shape=[None,25,25,3]) + keep_prob = tf.placeholder(tf.float32) + if mode == 'gt': + print 'Creating placeholder for ground truth' + y = tf.placeholder(tf.float32, shape=[None, 4]) + return (x, y, keep_prob) + if mode == 'no_gt': + print 'No placeholder for ground truth' + return (x, keep_prob) + + +def obj_comp_graph(x, keep_prob): + with tf.name_scope('obj') as obj_graph: + with tf.name_scope('conv1') as conv1: + W_conv1 = weight_variable([5,5,3,4]) + b_conv1 = bias_variable([4]) + h_conv1 = tf.nn.relu(conv2d(x, W_conv1) + b_conv1, name='h') + h_pool1 = max_pool_2x2(h_conv1) + h_conv1_drop = tf.nn.dropout(h_pool1, keep_prob, name='h_pool_drop') + with tf.name_scope('conv2') as conv2: + W_conv2 = weight_variable([3,3,4,8]) + b_conv2 = bias_variable([8]) + h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2, name='h') + h_pool2 = max_pool_2x2(h_conv2) + h_pool2_drop = tf.nn.dropout(h_pool2, keep_prob, name='h_pool_drop') + h_pool2_drop_flat = tf.reshape(h_pool2_drop, [-1, 392], name='h_pool_drop_flat') + with tf.name_scope('fc1') as fc1: + W_fc1 = weight_variable([392, 4]) + b_fc1 = bias_variable([4]) + y_pred = tf.nn.softmax(tf.matmul(h_pool2_drop_flat, W_fc1) + b_fc1) + tf.add_to_collection('obj_feat', h_pool2_drop_flat) + return y_pred + + +def atr_comp_graph(x, keep_prob, obj_feat): + with tf.name_scope('atr') as obj_graph: + with tf.name_scope('conv1') as conv1: + W_conv1 = weight_variable([5,5,3,4]) + b_conv1 = bias_variable([4]) + h_conv1 = tf.nn.relu(conv2d(x, W_conv1) + b_conv1, name='h') + h_pool1 = max_pool_2x2(h_conv1) + h_conv1_drop = tf.nn.dropout(h_pool1, keep_prob, name='h_pool_drop') + with tf.name_scope('conv2') as conv2: + W_conv2 = weight_variable([3,3,4,8]) + b_conv2 = bias_variable([8]) + h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2, name='h') + h_pool2 = max_pool_2x2(h_conv2) + h_pool2_drop = tf.nn.dropout(h_pool2, keep_prob, name='h_pool_drop') + h_pool2_drop_flat = tf.reshape(h_pool2_drop, [-1, 392], name='h_pool_drop_flat') + with tf.name_scope('fc1') as fc1: + W_obj_fc1 = weight_variable([392, 4], var_name='W_obj') + W_atr_fc1 = weight_variable([392, 4], var_name='W_atr') + b_fc1 = bias_variable([4]) + y_pred = tf.nn.softmax(tf.matmul(h_pool2_drop_flat, W_atr_fc1) + tf.matmul(obj_feat, W_obj_fc1) + b_fc1) + tf.add_to_collection('atr_feat', h_pool2_drop_flat) + return y_pred + + +def evaluation(y, y_pred): + correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_pred, 1), name='correct_prediction') + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy') + return accuracy + + +def loss(y, y_pred): + cross_entropy = -tf.reduce_sum(y * tf.log(y_pred), name='cross_entropy') + return cross_entropy + + +if __name__ == '__main__': + lg_dir = '/home/tanmay/Code/GenVQA/Exp_Results/lg_files/' + g = tf.Graph() + with g.as_default(): + x, y, keep_prob = placeholder_inputs(mode='gt') + y_pred = obj_comp_graph(x, keep_prob) + obj_feat = tf.get_collection('obj_feat', scope='obj/conv2') + y_pred2 = atr_comp_graph(x, keep_prob, obj_feat[0]) + accuracy = evaluation(y, y_pred2) + accuracy_summary = tf.scalar_summary('accuracy', accuracy) + sess = tf.Session() + sess.run(tf.initialize_all_variables()) + merged = tf.merge_all_summaries() + summary_writer = tf.train.SummaryWriter(lg_dir, graph_def=g.as_graph_def()) + result = sess.run([merged, y_pred], feed_dict={x: np.random.rand(10, 25, 25, 3), + y: np.random.rand(10, 4), + keep_prob: 1.0}) + summary_writer.add_summary(result[0], 1) diff --git a/classifiers/tf_graph_creation_helper.pyc b/classifiers/tf_graph_creation_helper.pyc index 2010602d68641b11f6d00b6ea9a5a1057af52cd9..ba8c1623b0ba07ff0da1f5e5eceedd6f69bfefaa 100644 GIT binary patch delta 360 zcma)2u}Z{15ZxD(oEtXTupxI*A>y5)=z@A$L{Mx5v9u5?u`nPAR%&TwW512ySI8Fr zft8KTfry=rjfMVzvs+opz%cK@dv9i*^quaE{iMm@{qvzS9fI5WFb6W*L`#6OZTyX1 zmoavss4>h!K_Ehn6!v_44-n>&Yw92T7`fR=4uV^l`$XHf<QdC-CtdJTMbZiqb7U~{ zXP)w=3MKlmBeCl&>r77w)<*0!V69|r&TfEf#V!d7bCO-5Ft6FRi<*=}dnVVhw&9`{ z{}ENu8YPD4!oK83ko!&;d*KT0Vz$&*dg!>&a|76>7zLXRApsGHh$^HKyvHjZ^GG13 Tn^J#WNw(5IsUWV#D&Z8rb!#$F delta 408 zcma)1u}T9$6x?~a+~v-_?72zK1R)_BQS?BKDPpCdAXYY7iU<di!b&ZzZ0xrYEc^!9 z!d_4tYkxp%3;hA#UZEBi^4MYK!OWYz${w?z^chDRtFN~~$3w-=dI=!yu)F{$oBJP_ z&_ak42ZDOgkSL>#nCx+78yGitVTPJ{3lC5?o6#Nx!dypLeYDJm(IptZnSF|$vDkHa zb48vejV2V$aBBA#pG1$9MuVov-peg;c2vz;smROXh(xYLo`?(JT#HkdO|w%yL(4o= z*W7(y_1iyJ|6+Ei23T2lP3Hg6q{%HC)Tfl}O|lPl=!|}0T;i%`Ouxq?xAThKG+Dp) mLc;+o2}#o9De<VV6L!N&7)bI-B6gog11Y0sUE|lLU3>vd05|CX diff --git a/classifiers/train_classifiers.py b/classifiers/train_classifiers.py new file mode 100644 index 0000000..6367807 --- /dev/null +++ b/classifiers/train_classifiers.py @@ -0,0 +1,40 @@ +import sys +import os +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +import numpy as np +import tensorflow as tf +import object_classifiers.train_obj_classifier as obj_trainer +import object_classifiers.eval_obj_classifier as obj_evaluator + +workflow = { + 'train_obj': False, + 'eval_obj': False, +} + +train_params = { + 'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier', + 'adam_lr': 0.001, + 'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json', + 'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images', +} + +eval_params = { + 'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier', + 'model_name': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier/obj_classifier', + 'global_step': 1, + 'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json', + 'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images', + 'html_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier/html_dir', + 'create_graph': False, +} + +if __name__=='__main__': + if workflow['train_obj']: + obj_trainer.train(train_params) + + obj_feat = tf.get_collection('obj_feat', scope='obj/conv2') + print(obj_feat) + if workflow['eval_obj']: + obj_evaluator.eval(eval_params) + -- GitLab