From 92f02b56957d36d58b1c5b4d7544de055358ba4a Mon Sep 17 00:00:00 2001 From: tgupta6 <tgupta6@illinois.edu> Date: Fri, 18 Mar 2016 13:17:45 -0500 Subject: [PATCH] optimized atr training and evaluation --- .../atr_data_io_helper.py | 32 +++++---- .../atr_data_io_helper.pyc | Bin 4641 -> 4538 bytes .../eval_atr_classifier.py | 14 +++- .../eval_atr_classifier.pyc | Bin 2804 -> 3079 bytes .../train_atr_classifier.py | 65 +++++++++++++----- .../train_atr_classifier.pyc | Bin 3616 -> 4071 bytes .../train_obj_classifier.py | 7 +- .../train_obj_classifier.pyc | Bin 3118 -> 3261 bytes classifiers/tf_graph_creation_helper.pyc | Bin 5656 -> 5656 bytes classifiers/train_classifiers.py | 8 ++- 10 files changed, 89 insertions(+), 37 deletions(-) diff --git a/classifiers/attribute_classifiers/atr_data_io_helper.py b/classifiers/attribute_classifiers/atr_data_io_helper.py index 0f707a9..572704d 100644 --- a/classifiers/attribute_classifiers/atr_data_io_helper.py +++ b/classifiers/attribute_classifiers/atr_data_io_helper.py @@ -6,47 +6,49 @@ import matplotlib.image as mpimg import numpy as np import tensorflow as tf from scipy import misc +import time -def atr_mini_batch_loader(json_filename, image_dir, mean_image, start_index, batch_size, img_height=100, img_width=100, channels=3): - - with open(json_filename, 'r') as json_file: - json_data = json.load(json_file) +def atr_mini_batch_loader(json_data, image_dir, mean_image, start_index, batch_size, img_height=100, img_width=100, channels=3): atr_images = np.empty(shape=[9*batch_size, img_height/3, img_width/3, channels]) atr_labels = np.zeros(shape=[9*batch_size, 4]) - for i in range(start_index, start_index + batch_size): + for i in xrange(start_index, start_index + batch_size): image_name = os.path.join(image_dir, str(i) + '.jpg') image = misc.imresize(mpimg.imread(image_name),(img_height, img_width), interp='nearest') + #image = np.zeros(shape=[img_height, img_width, channels]) crop_shape = np.array([image.shape[0], image.shape[1]])/3 - selected_anno = [q for q in json_data if q['image_id']==i] - grid_config = selected_anno[0]['config'] - + grid_config = json_data[i] counter = 0; - for grid_row in range(0,3): - for grid_col in range(0,3): + for grid_row in xrange(0,3): + for grid_col in xrange(0,3): start_row = grid_row*crop_shape[0] start_col = grid_col*crop_shape[1] cropped_image = image[start_row:start_row+crop_shape[0], start_col:start_col+crop_shape[1], :] + if np.ndim(mean_image)==0: atr_images[9*(i-start_index)+counter,:,:,:] = cropped_image/254.0 else: atr_images[9*(i-start_index)+counter,:,:,:] = (cropped_image / 254.0) - mean_image - atr_labels[9*(i-start_index)+counter, grid_config[6*grid_row+2*grid_col+1]] = 1 + + if grid_config[6*grid_row+2*grid_col]==0: # if the space is blank + atr_labels[9*(i-start_index)+counter,3] = 1 + else: + atr_labels[9*(i-start_index)+counter, grid_config[6*grid_row+2*grid_col+1]] = 1 counter = counter + 1 return (atr_images, atr_labels) -def mean_image_batch(json_filename, image_dir, start_index, batch_size, img_height=100, img_width=100, channels=3): - batch = atr_mini_batch_loader(json_filename, image_dir, np.empty([]), start_index, batch_size, img_height, img_width, channels) +def mean_image_batch(json_data, image_dir, start_index, batch_size, img_height=100, img_width=100, channels=3): + batch = atr_mini_batch_loader(json_data, image_dir, np.empty([]), start_index, batch_size, img_height, img_width, channels) mean_image = np.mean(batch[0], 0) return mean_image -def mean_image(json_filename, image_dir, num_images, batch_size, img_height=100, img_width=100, channels=3): +def mean_image(json_data, image_dir, num_images, batch_size, img_height=100, img_width=100, channels=3): max_iter = np.floor(num_images/batch_size) mean_image = np.zeros([img_height/3, img_width/3, channels]) for i in range(max_iter.astype(np.int16)): - mean_image = mean_image + mean_image_batch(json_filename, image_dir, 1+i*batch_size, batch_size, img_height, img_width, channels) + mean_image = mean_image + mean_image_batch(json_data, image_dir, 1+i*batch_size, batch_size, img_height, img_width, channels) mean_image = mean_image/max_iter return mean_image diff --git a/classifiers/attribute_classifiers/atr_data_io_helper.pyc b/classifiers/attribute_classifiers/atr_data_io_helper.pyc index b958afac0eda1f2f3ff3b1e2e2dc36c210e170a5..457ee8ea56d594f13e29953a5ac035b031255cd4 100644 GIT binary patch delta 1518 zcmZux&2Jl35dXdPN9^_5>yLH(k&w`aHYP|kpw=iQh_opPsffcOkzx^9CcAOg*lWjL zgoAcvgv%8Py&IgkfW!qM4oIAO=z$vt{sgXY;)K)#GtW_J)%v}iH*aR%%x~t+^S96a zy0H@bC#yfW@z;AtIR7PR-J@?({c6RZ0ggc5M9YK|Q!@@Hfx9uZ@{dT8N>X@4vZN$M zI4L+JPvb5y2i|X$;978We57G)3AY3%gSkMf3{Qh&z{#TNmeE^=lfy-a;+mLK=scRO z3RqKxTj3*uH#NB?uFlE?Sf_8FfuCwX0Q3Z7cQk}?SYd*(mBZv>Swc5Xa72_+mK6A* zr4vV2Q@S*!Yg!OPxR`NN<{EAtrnN{%5JxY8UJ{`|urx!cGa{-C+e8@zDfH3^bo4R^ zRf2?oq90Qz!mxz7CL%)xHxL@=Ww}zdb1sb<V<3lwIm#jg2dG6kHi}G4=Ynu?X=CIF zw`PZ95#=%U9v2_p5m5nQ9^t~f6%`SZfOZtYib@<gBpJ<xMdb@B7l~;uj9B1M69B{< z#|VR^K<5xH0H*jN>*F6psBX0!5-9SabJeO=Xv)`Om+I##4b`WpFU&&9i^`OTMfXnt ztxAn@;|95)5z;XIxJy(hZ7Srhzs-p}seaNQnN@0hB&2>Ko@#g+69=b%<KVPa!~ww~ zH)QZ`BTFsxj{G;aztbfSxHsJMlC)BH0u7T$z??q{_@%gmQE;ZXPu=lw60qe|ejmT{ ztYFJHTRPsL!3`RuiS#BzzwLB_j@(Fe<#&m<rJnrrN`^F(Eh5mi#Hv;o>ta*e6uL-? zWg0me*R-;DRom58#kwZfmWH=4leNtOYzCr24cw03;3#&1Ns)=ggu~m*Oe##OL=@-` zJHgS3U6MbVp1hGNi!FJO`fbKemHN0p9FA?yNo~g+zw4?LJCopS<SL=(2QM}4=gFqQ zFYy1!)clnE$Ah*v=ycu5Q!%8-WYC$my}%vYB|g3a$e+xU1aGeVoYXIgv`TF_a(%g- z*%U7`*c`ntpxyR-FKD-~%ilBiGFSQdH6|~}uZ>mFl)o5Xir3|v+0QdOByOl$HRT`K zZ^RpNnCsOao=}M#^@kI;!PjOtRY5*5P4SlOn=j^HVc{;5JtjA0RPdAbHt{=hEiWr| z;uK20U3!^xnP(F_CGY2NmY?Ah+*B$E)njyK$K^_)v9C@9-oUlhRdNXzHDJd!|CI&Y x7bn4>-~ME*O0>tC>r50H2J+eEi?lZc50eJ3d0}Y9^<tu^mzE995E~*Z{sR$^1`PlJ delta 1636 zcmZuxO=w(I6#nj;zs$^=pG@BT<gY48>9c6DsHxIalcq?qTJQ;`4c0K7_hy_iGm~VV zkidHcP2EUQ<St!^J6*f$s?d!~m*QGnxGIPR35egh6G92jefPX`zkBXE-#O>Ld9d*P z;%wrtLN52@&ktHS{-yX{;T+lD%nl9`@cAU+r*JQUjgnc8k+ea|PazqDOv0alPnpF% z;1IaIQ9&>TKZj2&>`thl0zZ#Kg$)N?3%(7%fTh4euL{2ipFv7EWRNmyOB*%FP8~r_ zjwo3)^m+Bl!8A$cIoE{jmlhBL5#TEvTKKw#cmkn9&q6dIC*v&Eec-)YVFEoS#J~we zNjOnT60An)@hgK)lB`vngIIF&#F9nmSY|Jey@Rq9rV-JHaZ_i|Yj=)RY$ME|H-TOj zkwTb5WJ`EI4Kaxe2=nOCsDNG(Q9<^|fYIBGlabmTS}HD@;3Y&w^jP97rjNB8riFvT zfsl?%=rHRnL(d^DBcg4dwzettk*q(muycfX5|L@iG|N35vjbVgQ`oz%m;`k?GYMQ$ zv4gz}r_gy7R}ndg*cwuC;u@j@B3zP>6W4`jM3B$er2Yl#PMFWwNC1|f50!R;92zD? z@-peruiTD*Rn(B{P>ph?Kem*&;Y3rYpPU{mI-PK4j%_V8PTT9ht>-=@yC{F!PsHTO z)`7!o7FFTYX;z<R`Ii|+bQWf3U9In|a(;$&x~3QIzwefL*KYDw0lWH-#M+H;f@-(l z>;#)#pOpk!!$G^-VN>{fI9%xp#!C*k6^bxPG29LYM(K@)V`XdD^z~|TE-4?MzMH(Y zC?|}~!0+~5x$3+yfSlNDcejE;vmfZ+lAYP?(`b0I@|$5(ucY3s3)dZUb`)#{tuXL6 zn}fkn-%tI1F~_~p3ly*_s%o8AbJmP{MxEn$UKOphs&N!KUR70f*;-cmyY%qpOSJdo zs3%{aA=wX_16N{ubrB5_Ga_b1%!!y6afTw49Pc*6*2kWtr5~p{`ccMFuKqRi^MUv^ z7VWL!aM!cB^dxap+vX_T-ww=ncMx7&@?PeSYp9&W<p&20KJ9Osu8jVdfi$h(+}rGi z!LA|UEr=sdy_jvKgw~5c-Orv?DK8`7gWS2~Wl@CdJ9%5Zrf=sTy32&TD<TZT1T{9s zFl`Ta`_0f5$a_Pt+Fz+F`WO52{3-!g?f}s{g>TiG{-@BJIdNoCYi$ijfh&EZQC!?p zZ|Nt+3)MFTxF+Jdh_@+R==p3f^*Yr{`iGL966T&{$A2!ic7Nm*wO77cJuCQ>;8J{6 zXM5kv=^x8(%`@GUI*tn=zPuayaoK5@Fjr-5qK*S{X{vBIIl1^O7q*g7;*fXQs+_5$ OD!HkeZP{u;71V$3dmys_ diff --git a/classifiers/attribute_classifiers/eval_atr_classifier.py b/classifiers/attribute_classifiers/eval_atr_classifier.py index c3222b1..6431390 100644 --- a/classifiers/attribute_classifiers/eval_atr_classifier.py +++ b/classifiers/attribute_classifiers/eval_atr_classifier.py @@ -1,4 +1,5 @@ import sys +import json import os import matplotlib.pyplot as plt import matplotlib.image as mpimg @@ -15,6 +16,7 @@ def eval(eval_params): x, y, keep_prob = graph_creator.placeholder_inputs() _ = graph_creator.obj_comp_graph(x, 1.0) obj_feat = tf.get_collection('obj_feat', scope='obj/conv2') + y_pred = graph_creator.atr_comp_graph(x, keep_prob, obj_feat[0]) accuracy = graph_creator.evaluation(y, y_pred) @@ -23,6 +25,13 @@ def eval(eval_params): mean_image = np.load(os.path.join(eval_params['out_dir'], 'mean_image.npy')) test_json_filename = eval_params['test_json'] + with open(test_json_filename, 'r') as json_file: + raw_json_data = json.load(json_file) + test_json_data = dict() + for entry in raw_json_data: + if entry['image_id'] not in test_json_data: + test_json_data[entry['image_id']]=entry['config'] + image_dir = eval_params['image_dir'] html_dir = eval_params['html_dir'] if not os.path.exists(html_dir): @@ -36,12 +45,13 @@ def eval(eval_params): color_dict = { 0: 'red', # blanks are treated as red 1: 'green', - 2: 'blue'} + 2: 'blue', + 3: 'blank'} batch_size = 100 correct = 0 for i in range(50): - test_batch = atr_data_loader.atr_mini_batch_loader(test_json_filename, image_dir, mean_image, 10000+i*batch_size, batch_size, 75, 75) + test_batch = atr_data_loader.atr_mini_batch_loader(test_json_data, image_dir, mean_image, 10000+i*batch_size, batch_size, 75, 75) feed_dict_test={x: test_batch[0], y: test_batch[1], keep_prob: 1.0} result = sess.run([accuracy, y_pred], feed_dict=feed_dict_test) correct = correct + result[0]*batch_size diff --git a/classifiers/attribute_classifiers/eval_atr_classifier.pyc b/classifiers/attribute_classifiers/eval_atr_classifier.pyc index 0be14f65263aa832ead7df9c9f5988c223c072a2..960456aa4a4ea6b14e9af628b1864cfd74c89fa3 100644 GIT binary patch delta 1196 zcmZ`&OK&4Z5dNy|iDSo)_!&>ejuUT)5X%LDB7r!7Hq!2utVqzHRq&>h$P;f~Ziz(d zkwJk)AR!dZ4gLUUICJ7hpuHjSIB`L6<iOVxTLKA2J>6B^UsZio)%|DllXj5#qg43m z-fuhmn0|`u8~n##-305CP(cG)Fv=jxBJ^hJ9HJb;`I(wWSV(1s(E^?n5f&j+%M!x! z3|B$8IFqXgYcsiyumPRHM?lzg@g6>(C5fAF{Bn=$0{?rC)b}^fewdgl>;iNV*9u>} zjkpL?LsWw0jvy|<)Nu{OWn8`tTf)|(yND~W73dOtTgGt?7oIr1oa)n)*UvrS*xe=Z zx8AYLy8=I6gp(F2Rbk4ADicHv*s6;)J`=aqoT_pHm=>a?TOcCZR-p-1DYF`e8E7W$ zVN=(k8(f%e3!8BrmN~RxNnIerEcDVXu6>L1Q=DFgo&?0RS{=kR_7Jyl`L!4Vy3Tg% z(U%gF0eLKg$y$<NZ&qNJV4G+xane0Eu5*anj*?{zkCjl^bYg>IhSUT$%TPY7APQjT zVRJZA6X;b~pRpungcdAGIdkWnWVmzx!DZbsa-r7HNLdjbz!VX6(s(XNkhUNASm98? zCoJyFAoTe^kT<)q^wGt!!uY95don9^9}!QPXpq)iAzp!5LDYj~0j6mN$p@JiC+uMd zzH5c0*rmvJxF)PZ|J#ke1CDF9?ms$p{#0@7p8qbp&Unq;JkjeZj~4WEXfmz)4_pcR z>6rl{ddoW)c+WSISF)dVg}3{{%yS+-85|w#4Yir0vVVGVI6O+e%U)HCNB!|=aB|F! z0m-JfBkRfM-p|YKrJSA}oERrW<I`I~5BE)SHusD4lAYYATZRGC!@a}d=)jbyT@2pe zbMQSqFaws6m-k?7{;}mE((T37wtIlv#czMjb>yxy9gjFcimE9!X)0f8sv;YzF9qrq z$xBHVrKgg${BrVY{-t`0Xya&)ug2ROT(|qlW&S-`O`hfZ(n$Ww2l7JFF04r_d9AQ2 S50iHbAI>?eq9i{S8h-;kjN#7! delta 949 zcmZuvO^Xvj5Pj8?Y_ju_k4a`HlgWOItSATyISUF22p)EE7X}o;M%i%{mdqhS4?(gb z>qVBn`3t;w@?ZD|ym;^sFMHLCC#{}M4x%&NGq1Z}RlTa}H=UnBZ2j=;?QiG%k1>A) z|8Maey*rC9O;QU<2@;TmqcYMZB$dUoie$OK0!KEUJ4hVRoa!R+7Q7mg`T{qQG#A)M z(t@;b0wiq{U~i|yfe(N4b8m+jW@n$hTFk)ad7I~G|LpshiDEe}qTGc+9wQG~4QU-| z6UxJsFtW+)`12*bU}r6)?FFx44rC2eVc2;MhZZEOETb<QkWK#ikS+A{Iutt!pz4N6 z%$rD7Yly)Gapi)@W+8IUu8iEr_yc8p&@o1L9=xZ#RR{>794HoULq<?7s3zJWe`ME; zX$5)PXi-IoND1@}Y$8$?M<l2cCjFe{5Yi4*6{>=#0y2gQI2PqAP@FY~F0+ks&SH3H zFdVmxA`_~IcA<*dAu<;!E13rgi{=2Y1)c~*rZ$}gg?$<w;oiR-cEDmHJQEm2V&033 z8m}Vl7Kw&Mu?6xjvKVO(iudPfhWeLvZNdf}M5eTgVZ@O8V5+|g`L`Rn1|C!H*?(Ge z90?u=$0u}PjF|_%m+OTNZt4aCPHQkU6<jaQG9cZ+blK3rw{TN`EWOqD%VYhuyic3@ z=F&{xsLUwVVfD1KTjh(myZmWZ?a{hxF6Bf1fLsxfPXR^bi#lBvTV#{IYIpU(zFrHM zky(cPa<sdqPwhw4)93aUHMHl%w5_i=o8;@e&IS$jk@Kq4X0A&v)%kU)!LxY9AMzWg ATmS$7 diff --git a/classifiers/attribute_classifiers/train_atr_classifier.py b/classifiers/attribute_classifiers/train_atr_classifier.py index 5145bae..0c59564 100644 --- a/classifiers/attribute_classifiers/train_atr_classifier.py +++ b/classifiers/attribute_classifiers/train_atr_classifier.py @@ -1,5 +1,7 @@ import sys +import json import os +import time import matplotlib.pyplot as plt import matplotlib.image as mpimg import numpy as np @@ -14,15 +16,30 @@ def train(train_params): x, y, keep_prob = graph_creator.placeholder_inputs() _ = graph_creator.obj_comp_graph(x, 1.0) obj_feat = tf.get_collection('obj_feat', scope='obj/conv2') + + # Session Saver + obj_saver = tf.train.Saver() + + # Restore obj network parameters + obj_saver.restore(sess, train_params['obj_model_name'] + '-' + str(train_params['obj_global_step'])) + y_pred = graph_creator.atr_comp_graph(x, keep_prob, obj_feat[0]) cross_entropy = graph_creator.loss(y, y_pred) vars_to_opt = tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr') + vars_to_restore = tf.get_collection(tf.GraphKeys.VARIABLES, scope='obj') + train_step = tf.train.AdamOptimizer(train_params['adam_lr']).minimize(cross_entropy, var_list=vars_to_opt) + + all_vars = tf.get_collection(tf.GraphKeys.VARIABLES) + vars_to_init = [var for var in all_vars if var not in vars_to_restore] print('Variables that are being optimized: ' + ' '.join([var.name for var in vars_to_opt])) - print('All variables: ' + ' '.join([var.name for var in all_vars])) - train_step = tf.train.AdamOptimizer(train_params['adam_lr']).minimize(cross_entropy, var_list=vars_to_opt) + print('Variables that will be initialized: ' + ' '.join([var.name for var in vars_to_init])) accuracy = graph_creator.evaluation(y, y_pred) + # Session saver for atr variables + atr_saver = tf.train.Saver(vars_to_opt) + obj_atr_saver = tf.train.Saver(all_vars) + outdir = train_params['out_dir'] if not os.path.exists(outdir): os.mkdir(outdir) @@ -31,23 +48,32 @@ def train(train_params): img_width = 75 img_height = 75 train_json_filename = train_params['train_json'] + with open(train_json_filename, 'r') as json_file: + raw_json_data = json.load(json_file) + train_json_data = dict() + for entry in raw_json_data: + if entry['image_id'] not in train_json_data: + train_json_data[entry['image_id']]=entry['config'] + + image_dir = train_params['image_dir'] - mean_image = atr_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width) + if train_params['mean_image']=='': + print('Computing mean image') + mean_image = atr_data_loader.mean_image(train_json_data, image_dir, 1000, 100, img_height, img_width) + else: + print('Loading mean image') + mean_image = np.load(train_params['mean_image']) np.save(os.path.join(outdir, 'mean_image.npy'), mean_image) # Val Data - val_batch = atr_data_loader.atr_mini_batch_loader(train_json_filename, image_dir, mean_image, 9501, 499, img_height, img_width) + print('Loading validation data') + val_batch = atr_data_loader.atr_mini_batch_loader(train_json_data, image_dir, mean_image, 9501, 499, img_height, img_width) feed_dict_val={x: val_batch[0], y: val_batch[1], keep_prob: 1.0} - # Session Saver - saver = tf.train.Saver() - + # Start Training - sess.run(tf.initialize_all_variables()) - - # Restore obj network parameters - saver.restore(sess, train_params['obj_model_name'] + '-' + str(train_params['obj_global_step'])) - + print('Initializing variables') + sess.run(tf.initialize_variables(vars_to_init)) batch_size = 10 max_epoch = 2 @@ -57,18 +83,23 @@ def train(train_params): train_acc_array_epoch = np.zeros([max_epoch]) for epoch in range(max_epoch): for i in range(max_iter): - train_batch = atr_data_loader.atr_mini_batch_loader(train_json_filename, image_dir, mean_image, 1+i*batch_size, batch_size, img_height, img_width) + print(i) + # start = time.time() + train_batch = atr_data_loader.atr_mini_batch_loader(train_json_data, image_dir, mean_image, 1+i*batch_size, batch_size, img_height, img_width) + # end = time.time() + # print(end-start) feed_dict_train={x: train_batch[0], y: train_batch[1], keep_prob: 0.5} _, current_train_batch_acc = sess.run([train_step, accuracy], feed_dict=feed_dict_train) train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] + current_train_batch_acc - val_acc_array_iter[i+epoch*max_iter] = accuracy.eval(feed_dict_val) - plotter.plot_accuracy(np.arange(0,i+1+epoch*max_iter)+1, val_acc_array_iter[0:i+1+epoch*max_iter], xlim=[1, max_epoch*max_iter], ylim=[0, 1.], savePath=os.path.join(outdir,'valAcc_vs_iter.pdf')) - print('Step: {} Val Accuracy: {}'.format(i+1+epoch*max_iter, val_acc_array_iter[i+epoch*max_iter])) + # val_acc_array_iter[i+epoch*max_iter] = accuracy.eval(feed_dict_val) + # plotter.plot_accuracy(np.arange(0,i+1+epoch*max_iter)+1, val_acc_array_iter[0:i+1+epoch*max_iter], xlim=[1, max_epoch*max_iter], ylim=[0, 1.], savePath=os.path.join(outdir,'valAcc_vs_iter.pdf')) + # print('Step: {} Val Accuracy: {}'.format(i+1+epoch*max_iter, val_acc_array_iter[i+epoch*max_iter])) train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] / max_iter val_acc_array_epoch[epoch] = val_acc_array_iter[i+epoch*max_iter] plotter.plot_accuracies(xdata=np.arange(0,epoch+1)+1, ydata_train=train_acc_array_epoch[0:epoch+1], ydata_val=val_acc_array_epoch[0:epoch+1], xlim=[1, max_epoch], ylim=[0, 1.], savePath=os.path.join(outdir,'acc_vs_epoch.pdf')) - save_path = saver.save(sess, os.path.join(outdir,'atr_classifier'), global_step=epoch) + _ = atr_saver.save(sess, os.path.join(outdir,'atr_classifier'), global_step=epoch) + _ = obj_atr_saver.save(sess, os.path.join(outdir,'obj_atr_classifier'), global_step=epoch) sess.close() diff --git a/classifiers/attribute_classifiers/train_atr_classifier.pyc b/classifiers/attribute_classifiers/train_atr_classifier.pyc index 78a4e66e7a05fc26d316c57a52f11a19907ab03e..500169dccf17b7b3efcb303cb4ea2258248b13b3 100644 GIT binary patch literal 4071 zcmcgvOK%*<5w70lF2$EfF5eO*E=5tItXNE(KnVgNwk#X6;74K#k&=a-!EmR!?A7ed zXr`Cq8Z#&5luNF;<s1Y#BnSVV9CHYe0R92-SJg9I($+bA^j23@*YoRYvi{B7^xxnA z_p>gQUp0Kbhu`dPSOWYg3W+*rw4#ox_%-U(D!NXcdPPrAzd@a5wVkB?CF)F7+iB{| zRP-!$<|_I!b>=I2fjWy7-J;GC1vUCT!dND~LP4Eg)F}N6>1BFNq#}v(>D~k)Hvhxd zXI%yIE&MLx_Yr=x4;w_5sIT}$p~%*#UuXJ@Iu$i4uucebf^y^pH82};q`}VNGnk;) zN)=6EHmN@;rY0$u{BlJ(;%LxD(4;jxCA&%5c&%s+wiG*@X4^SDpfr~tg40XlbcXUe z<(I_iROytXouxj?1J_gHdiqVGX%Wp-QC$|tGvc^EPdUmyOSw~;{qA`LSIkk5D4(Mb zzFemQr|~mGFD5D7QlNiVqnBH%Kw}WdW$G`8_B6ETOYQ!mD1E9p)5jkLP(9O||GwU^ z)Z==z&ZWIjC4ho*v{Wpf<ID*M`Y7=NZ4^snm#Du?{T0fWSeE@qF=w^h3MLk*Sf&6g zlyI(vbE{<6C|{)jhidf4UsJJ0cAfegl&^`4b+T8auN!pwW6IZ&z+Eb?P=1AiW!lI$ zq%`+yykC)hDS7bnehoj*pX1Ja{u}nU3V&C{^BO?0E|~pmE!d#PAK9za-y|5zuTlU@ zP(zGBbCZfqiq<LEWUXrw7)-$cS7eizb-l`cfs+lcaVGgDuPV4+8s937gCA^B(5~!H zQgMw8uy#!l2sqwnavPb{={{f&blxwnQ*Z-@0O@{Pny^jzJeLPQD%ha{CT~*FreK$f zZ3*fI<xpzVa!<618^YaTR>|4<9_J>V{z=)JtOLUzeEA1e?1Dwi+AdjC+@k!}8?u5o zNVaHs4NF-)WOrM#yFJbhWR)HCIlDb_iRHK#fHUs)ZOZrHPkkYGzPLj<BG^Su?$DYe zAH4iW{dvXRrj?SLBW+mS<#?CZHge><OAqWW^>0zICkaa^;35j11N;$)15ALDy&b$o zIC`FOmle5oIPn9PpfyateX;lVxa+T}sq#g}1@bDdg@eXEi-RAspVu{2G$?PqN$p49 zReSkoB#Y*wC>XCii0%jlXg}S12jlcJj12zz@4NOAKH|f^ccgu5F=Q$0#);0R9DBbT zM<?&t%casF4s_^6{y<xv&~MxMQqT+ILqGH~s}lnj**X^Art%5%c9QpjKk&j-T%P!; z7n;mwHz9cDr^Y`Fb=J1WzHR%dZXaqB_1bY_&A@!6gP*s>Vf#($uS^(1zilF8jUSfk z>@Xf$FEFY37_^<{CG}0@^|Lq<b1oD!@O#=bfo(zq`5u`b#y_4S;p4m<9cVxDgoS#d zO3PRt#e-yMIZ&3iMQZ+q)D0?4SqsZ2u^(Ji$heqeYxkmLl&zi9JV6S9Z%rJv@!^|m zyRi2S%5Avh*(#PNRi<nim2|~3oT1_O-+YZ2$`=3gT}kFh>j1Urx*nghWjh1a5mTPj zMfBVrG1v)6PeU^hV#EYPCG$`8uY7x)Enwq!yWUCW=_KwR-%Elc!3nzVbwfYP%#qQl zI5~$7%SB<lT~yoS#~x1#mLpMl5?P)4U29JCQw>|_*c_z2)K8ARZVIr*sVF9)-_^%) z80ge95!#WRp|y+lcydBmJ#D>i9EKWJVUDrx@@DEO@=67zsm@?q3vx3cnA;`6j;(QN z!Z^#MVvjkbkM+pyo;^Hx^6>pnK71<SJ_O=_Sz_3MZ2*)qqNS?+*hE6ejpFucRDLMU zmtw@3h!YfC(5X-P%1G#g7mQ4f1lQGB5nQ|2{6Iv~aK59QfOI+p9(8)a-4hQ{B8uEr zyHYB0jNYMdyT_hWms+I5$TH~1*UUSEo<yVvi5&^`l}-^kmo@bxl%4Nl62=xSm5IT$ z2k;EhVS{q>#q*siAL^pE8iNqWH3%=zNB%Ihu9Xfcvk<z=oZ#TVl_>*I<X)f3I+FIh z&^qywG(HrbcYuQ=0tO+cor?PrbaWt1>mqo7I<je;j0C^PGxKchVH`>P)xLttc(m{U zlzZZbA(tnU(t)rXadnPM=(Jo7T%_D17Z4-8x+CKdG0nK;ZZ!40SD0AGGMdbw=N)U) zJGN4^@sRV5OsEB0e2C#89LX~EU&$a8SV&gq&dGS-R&W(Y2aF&G8MXt6I&_Ggv%!LN zLkY(rCu5Ma`oKSR6GAfPonajC9AHa4L;N)LM~=;1zsQ#9BG_~Rjt@AAgvk>~8f>It zPE!Jr)W$)I1b;&UWy3T@dT#cWaEI+Zh)9;~x1jcq<AL6{el+k$`;YMa?LXGhv(Fz^ zALB`v?c)JW&Ee4MbMr6TFQdLN74IdZgB!5Ujn445CSK{;+KOtamHK@BR&4{+RebN( zSJbZB!*{FRQmxuF_zU%>YAUAY)s&i33u>!AjkTe+YYnx89VoRk(NL>3)@-U(=(ZqP zQjOZ3`VOS4jTX3>x~=Zwtn|h62K+R7d;Gq***D4#euB-xK67}i;4Rjc|05gtmalMV z4v~mlY`(!I$uEBrTFWEmoT7U=xmjaskN~<eC!*m1CHWbo?w!_A7N<vH{OVwnc?_1b z5oThbJxnmge#)|SUXI*wa4!R<d!=!mh5XLR$SIG#hrN`)%$#osG6&Z+{{PT@@%#XP bP7Oo-3whVC!27In-}C5gMbFlzyjk@v*e<|h literal 3616 zcmcgtOK&5`5w0dBN}{O8QV)vKYL~KhH`*9p0Xb|CV-ZKT1KDuk#1fz&gzZ6#(<FyG zubpnmqBC<+{y+`^@*9$4fc%8~iCl6FkV^pnfaI$hQV%<TZ{ee-y1JfURrOTAS*ri( z;lIv$H2GBU{2BiC4_JKs6oo|H8(L9UP5BiXRjE5O-OkdeM%}sTcAmNmQ@T#w#VOsO z?$VTArtZp=UZw5>@+<T^xNDMGBfm<oE0q17Op~sOR4mb)9?T%J;6M0$+EXC+@n6ON z1N`mZaOj9eiZu#FPGSW8d6kL^6<BA4IYW7cle-3UR#ax$IehvvbgfiT6R{eL<q)os zKTGXv=+DW)9OYHY=R|*gqR)=!X@sm{r6yJuzI}A9sHl?<lptBB0a3m{AAYey#Uka4 zG+LxDEA-|sP^Q-LM?P%b$ko3q_xFh$dv4HZiE?06N1TdA%aqTP5336jNkhV3xf6CN z-Jo`{Om3A%4`|e+e3?z!GsT>>$yT6Wq5`$Tx+-3lq4z+j6;Mrb>onS+9IPh!tK@@# z;0j&Ws8}PnNu$=(OP$U7C@P>`<5jWp)<hXBAHVgf{u~wS6aegXDFN#7lF1EfqXI8c zLX_uau^~VMgSJ$DlX5_?EzBL}ASsn2h=+W@C~Z>PZ&Azb%5F<AYjG@on~E0s+f;0k zze7b^f`3GLoANDMAE0MYLzsrO$IO~=cD}=TNLQVT>#z(|KK!CtDR$}VSTPG>GV3R* zcqGu>%^onqU5nORSSBa}g~x)z<69I!PEkNwi~<Zx+IfTWB^scgkzNO`)aTOY#UAAl z-34}g)GEnWkQ{-(!4c)#;u~6RxPRr|qS3zO^@QCbEO=}|Uygy6Y`6c%*`6=5NS#GF z5(ymTV!Cs`752CZt}9jL!{5_yySX1z8=}V_z4>$H*V9bev@y-(Ui9sJhf|d#o{@hy zk6j6T62Je1LE-dD_MVRrcJ48zxqMF^&m<43NvEN7xPQvIDm(rMeC!{vpX-V$-l8{u zL>4F*4sgHm5FZGId-wF<IqtJhaX;|q-|xA3Jd*Q~*Eia^S!`@CNsV17*@wL(zWh;u z4Y?{7?K1lW&oi9``aCqY<A&OGbY?o|CWr@}By~X)d~N)9JC1MBj$Ofe6o#G4sR~Z! z*p#m$FU%wWoOxkjog5?=&hvvT_yj_(&P%3)*c;g-b{tj^>46c6e<Gv+Ns(Q~5*ZzP za`+-n$HDGX9Dais0*=G61m_tgz6rfpM@Hg%$8mgoFbI=#9eUQ8G+?g*zp{eA!o>uv zf4nE|`^Nb2-g6G`1iA`?NQf~L45rnW=2zMcZ4;YI9Uk?1-lg>dXR;Tm-?!UfpTfzz zo!=In&Y2E7khsWnZ_FGi!IiI_c8d@kGvx^Y;=J6$s^wRcE=`*>=?&SoB#9Dvy--^l z^aGQLl{=|P`SJDwB7MgZejFn(-E+aEIW<rM$V*@kGMx^+UWTqnG7(He-7~`^^iAdk zaeCqG4W_*$N<GmK*1$NgmxQ5#RhVPsC8%2|MlMXOm7I>*)hA|L?#_-*K0bQ?7e7B0 zR7Xh=3n2yL_AIgB5RcG6znrvN*+x+iOY=$B7@%~aB?3-2v0^7hWn2w=bA^7h0(bNp z?IwM|VG815qLHWAxos<fEl$M_SJ`dA4y&+V?_9gy(3>c7q^A(cWNNaD*lmFWF&F3% z<KgKr^lm0X+CDQDP#aOOE|bVil)AA{Uz-dWu-Qz<0GJWv3`I2!6Nf5F;h6T&7}8Oa zsV+4*>3)($7;|p<yK(_sv*{jcV%!=i1cM^_rmrtTSC;l*AAHx~MdmSu-KYfBqH$!J zu4ElcSzjBIdTExN3(q@&B1t@BaE&j)_K=kIOzg5G9ZRbM6YIIe!#yI?G{MOVEI7P# zA@HP@GIz7Z?ZMS7FN%TpRp6uGGSAV#8=7D+bdurirQ-F2&~Sg9@c5qmkmF=^pj#GI zLcJM#lr0prxCP31%F<iZRTLwA<;hG+biCu?mmy-CJheK@^tfa*-1pg~E;m;h<gG$^ z!lf!HzS#o7t(25Nq`bG`%86@15<QUt>+(98^eb#j%$$tJ!*KvXlEwZ8)Zs9R%%Rh9 zq{oLJ;O;y;Hu2eKN7KjcO=AzySy^y?;mn=;#~w~{xOXePNXI98u+CSU{SphQMrA`a z)rM+g-NEw!->0fssbjxdt*M%-<NHWGuI{P*>PmI3`b@YrwF!x)TEq9b+E;s`pQ?8% zO`KTucD1gg4xShBYs5QPhr4ph#`feXZysR5T<}PQ!8wAHf#QLgko@XPLx;ESJwbWd zNH5${Q5r-8;l&pb(tH~#Wz(A2CRsmBzB-Z8^5uJT-vkMIGfYi(@|@-QvT*%!?#fFT z7wJ@jODQ#%5uT|17)*h5k}LN!as{_6=KrCh;`tTcuNR^DDR-vb#L`d=wSq4{<y%>J HzoGsM4fITj diff --git a/classifiers/object_classifiers/train_obj_classifier.py b/classifiers/object_classifiers/train_obj_classifier.py index 45d5830..debc76d 100644 --- a/classifiers/object_classifiers/train_obj_classifier.py +++ b/classifiers/object_classifiers/train_obj_classifier.py @@ -26,7 +26,12 @@ def train(train_params): img_height = 75 train_json_filename = train_params['train_json'] image_dir = train_params['image_dir'] - mean_image = shape_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width) + if train_params['mean_image']=='': + print('Computing mean image') + mean_image = atr_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width) + else: + print('Loading mean image') + mean_image = np.load(train_params['mean_image']) np.save(os.path.join(outdir, 'mean_image.npy'), mean_image) # Val Data diff --git a/classifiers/object_classifiers/train_obj_classifier.pyc b/classifiers/object_classifiers/train_obj_classifier.pyc index eac87e3278c2ff50ee06780b0e6328693ac40b63..b0dd5d511a461b073a33139526028efddda13b3c 100644 GIT binary patch delta 1088 zcmah{J8u&~5dL=0Ifpwtex09j{0Imk5CsTDyhOo6NfD(81!<76w8r>wY{xzz#hzTa zp|$QO(5Fd_XhKIzNk>6}XlN*zx!90INa%EPbKiV3^UdrY6y6u3#?QFrUHd+bev;+L z^z+eJ1Ik2EMe$N%_zP+g>l@w@+yrzCIu3OX%7V6`Elf=KOXwA#Y-T0lrZ6io5gcMD z4bOsW!_7cbZVVF%y~I{{4&3alE<q}ev09qfj>1o%V?ZaM5-4hkHFlHa8dy`*DS|aW z2@h~hxH(J!pHKleAA$u&NQY2q_-UvN{0x)>-wAaUp&jT9oVH+=pJjB%EipEOav|{~ zlE;2cddtkAoQF@-(l20tT`=SvL$`!q#M#M`!3SBrIv=C-cMKYz$Ds63w~j(DqAj6v zfo)tIm!XqfRuhi|7r2X)zw52Qt%TyLyc`PAUh?xWtv3A4{v%dbc@C0(uEdS#^k}WA zuqg>Yc~+$xd`uH*2(=8icDB!HQAMJzZ|n&fd^IkQNnjnSd_p7jA&oO_Y6Y5vko04? zt}4*EQ@~v3e>7G@i7<su@m-@e@(4Ho+_xIXRh=76r0^@)`yrn2vRC<f2);e%Lr(#Z zFRcb2Blotb0<ha^_8OhtX1hfkVC0Z_(BB=r@;beCjia?N+WBSx9m>WiO=h#-RHq4t z@r9TW7lM*`bUQ(erZ;S;rq^tAiP{=&M+oyIncl!-pW)CMHD9+v$L)7|!51@sSdLDG zh#Vd)y5f44s5ZT$9WoFxnH6bzsv<2d`X&uin8KnbDHcRsE{IHU-`WzzU}`<fRwyN$ z(rfib{o(U&|IKzHFzrWTHQ2Q4%PEd?^+C7qHD0v30}^_L{<pKiTl<}HCnkOa4u8-H delta 923 zcmah{J#W)c6g|%lNo(iZPMkDpK02TXi46jXj|o&Jgv!8#f~9ArvC<?h7V7Y<sLGC1 zEdBsYEG(H}>d4Z8i9di45(9#ndrbmuL1MA)JNI1QbI-kRt8XeP{aePk^5gSi>b>?w zKm2es4YKfaP-&<PWD+t3WkMO4B@t#3<lz?(81PN_MW|KC70fiKEPn+TTZ>|XB&CgI zeF|X?105;{X`-q$bU92@sG}=0PziLyX&wRKC*j+e0bv13z%RvMnUR@*kXeL9$Q(io zGLO)XbvaNrlm)vlm=#hHsjV=yK+1n14)$AGP-PC4+<sK-Xki8Wn}RV?z}PhrR&jQ+ zbo5ngUR=(x_Gb<f%yZNhx)u4ig1!b>jIQhFr)yAo{;euOQ43OM!**^2HTd;d-8#EN zDKf@crpem)S^7t=Ey)Ag7THdJ;zl;J$WM;RgOj@|U*a{*G+jsse&cMP)1s<G-MqVd zEFwE`VM+p<kd+e}E3v4BHc7)cvAkb!ThgTEQ@~Q^e>B!&iAV}n;9H~h^9Wae-&>vY za-F@KX(DW3_nmmiN7CZ+5PUuDIFi7FP9yrBxMd7J>A-*h%n#0`h1iImC68_v=>25d zop?Rz25xUScBQvR&43-x9rXq~Zr=+udX6SRg1PQfFAn_iU^IW3dM-qLetFdwhDBME z8TX^tMpHDRW8<E1qL%qstVai?B{I<q^Ny%RAI)|xN5p{Hli@h%ZF|Fsx3|@y_m&ko K=~sF;Eq(%qwX%Bv diff --git a/classifiers/tf_graph_creation_helper.pyc b/classifiers/tf_graph_creation_helper.pyc index ba8c1623b0ba07ff0da1f5e5eceedd6f69bfefaa..cc7087d41170050991a04ef6986c615efc6ba662 100644 GIT binary patch delta 16 XcmbQCGed`+`7<w9K=SL2?1EwdEs_Nq delta 16 XcmbQCGed`+`7<w<^4*sk*#*S_FDV6O diff --git a/classifiers/train_classifiers.py b/classifiers/train_classifiers.py index d2edabb..e7d0ba9 100644 --- a/classifiers/train_classifiers.py +++ b/classifiers/train_classifiers.py @@ -12,7 +12,7 @@ import attribute_classifiers.eval_atr_classifier as atr_evaluator workflow = { 'train_obj': False, 'eval_obj': False, - 'train_atr': True, + 'train_atr': False, 'eval_atr': True, } @@ -21,6 +21,8 @@ obj_classifier_train_params = { 'adam_lr': 0.001, 'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json', 'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images', + 'mean_image': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier/mean_image.npy', +# 'mean_image': '', } obj_classifier_eval_params = { @@ -39,11 +41,13 @@ atr_classifier_train_params = { 'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images', 'obj_model_name': obj_classifier_eval_params['model_name'], 'obj_global_step': 1, + 'mean_image': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier/mean_image.npy', +# 'mean_image': '', } atr_classifier_eval_params = { 'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier', - 'model_name': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/atr_classifier', + 'model_name': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/obj_atr_classifier', 'global_step': 1, 'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json', 'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images', -- GitLab