From 99ac4c72adfe64ae6c4f84d87c78946ae232f2c8 Mon Sep 17 00:00:00 2001 From: tgupta6 <tgupta6@illinois.edu> Date: Thu, 17 Mar 2016 16:10:28 -0500 Subject: [PATCH] Modified code with variable name scoping --- classifiers/attribute_classifiers/__init__.py | 0 .../attribute_classifiers/__init__.pyc | Bin 0 -> 164 bytes .../atr_data_io_helper.py | 83 ++++++++++++++++++ .../atr_data_io_helper.pyc | Bin 0 -> 4641 bytes .../eval_atr_classifier.py | 66 ++++++++++++++ .../eval_atr_classifier.pyc | Bin 0 -> 2804 bytes .../train_atr_classifier.py | 64 +++++++------- .../train_atr_classifier.pyc | Bin 0 -> 3616 bytes .../object_classifiers/eval_obj_classifier.py | 36 +++++--- .../eval_obj_classifier.pyc | Bin 2671 -> 2699 bytes .../object_classifiers/obj_data_io_helper.py | 13 ++- .../object_classifiers/obj_data_io_helper.pyc | Bin 4653 -> 4614 bytes .../train_obj_classifier.py | 7 +- .../train_obj_classifier.pyc | Bin 3113 -> 3118 bytes classifiers/train_classifiers.py | 38 ++++++-- 15 files changed, 246 insertions(+), 61 deletions(-) create mode 100644 classifiers/attribute_classifiers/__init__.py create mode 100644 classifiers/attribute_classifiers/__init__.pyc create mode 100644 classifiers/attribute_classifiers/atr_data_io_helper.py create mode 100644 classifiers/attribute_classifiers/atr_data_io_helper.pyc create mode 100644 classifiers/attribute_classifiers/eval_atr_classifier.py create mode 100644 classifiers/attribute_classifiers/eval_atr_classifier.pyc create mode 100644 classifiers/attribute_classifiers/train_atr_classifier.pyc diff --git a/classifiers/attribute_classifiers/__init__.py b/classifiers/attribute_classifiers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/classifiers/attribute_classifiers/__init__.pyc b/classifiers/attribute_classifiers/__init__.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae52301a780f66bda35035e28368b85b5b73608d GIT binary patch literal 164 zcmZSn%*&;7_ElIi0~9a<X$K%KW&si@3=F{<AQ3+eAi;n}6#D|j^fU5vQ}s&{^Kug_ z^_}xmQuW<a^TGlhVN`NXVsUY1T4ridv3_DnNl|7}X-R54vS@sKW?p7Ve7s&kWeEq+ TM4R0Fl+v73JCMD_K+FID(HAGQ literal 0 HcmV?d00001 diff --git a/classifiers/attribute_classifiers/atr_data_io_helper.py b/classifiers/attribute_classifiers/atr_data_io_helper.py new file mode 100644 index 0000000..0f707a9 --- /dev/null +++ b/classifiers/attribute_classifiers/atr_data_io_helper.py @@ -0,0 +1,83 @@ +import json +import sys +import os +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +import numpy as np +import tensorflow as tf +from scipy import misc + +def atr_mini_batch_loader(json_filename, image_dir, mean_image, start_index, batch_size, img_height=100, img_width=100, channels=3): + + with open(json_filename, 'r') as json_file: + json_data = json.load(json_file) + + atr_images = np.empty(shape=[9*batch_size, img_height/3, img_width/3, channels]) + atr_labels = np.zeros(shape=[9*batch_size, 4]) + + for i in range(start_index, start_index + batch_size): + image_name = os.path.join(image_dir, str(i) + '.jpg') + image = misc.imresize(mpimg.imread(image_name),(img_height, img_width), interp='nearest') + crop_shape = np.array([image.shape[0], image.shape[1]])/3 + selected_anno = [q for q in json_data if q['image_id']==i] + grid_config = selected_anno[0]['config'] + + counter = 0; + for grid_row in range(0,3): + for grid_col in range(0,3): + start_row = grid_row*crop_shape[0] + start_col = grid_col*crop_shape[1] + cropped_image = image[start_row:start_row+crop_shape[0], start_col:start_col+crop_shape[1], :] + if np.ndim(mean_image)==0: + atr_images[9*(i-start_index)+counter,:,:,:] = cropped_image/254.0 + else: + atr_images[9*(i-start_index)+counter,:,:,:] = (cropped_image / 254.0) - mean_image + atr_labels[9*(i-start_index)+counter, grid_config[6*grid_row+2*grid_col+1]] = 1 + counter = counter + 1 + + return (atr_images, atr_labels) + +def mean_image_batch(json_filename, image_dir, start_index, batch_size, img_height=100, img_width=100, channels=3): + batch = atr_mini_batch_loader(json_filename, image_dir, np.empty([]), start_index, batch_size, img_height, img_width, channels) + mean_image = np.mean(batch[0], 0) + return mean_image + +def mean_image(json_filename, image_dir, num_images, batch_size, img_height=100, img_width=100, channels=3): + max_iter = np.floor(num_images/batch_size) + mean_image = np.zeros([img_height/3, img_width/3, channels]) + for i in range(max_iter.astype(np.int16)): + mean_image = mean_image + mean_image_batch(json_filename, image_dir, 1+i*batch_size, batch_size, img_height, img_width, channels) + + mean_image = mean_image/max_iter + return mean_image + + +class html_atr_table_writer(): + def __init__(self, filename): + self.filename = filename + self.html_file = open(self.filename, 'w') + self.html_file.write("""<!DOCTYPE html>\n<html>\n<body>\n<table border="1" style="width:100%"> \n""") + + def add_element(self, col_dict): + self.html_file.write(' <tr>\n') + for key in range(len(col_dict)): + self.html_file.write(""" <td>{}</td>\n""".format(col_dict[key])) + self.html_file.write(' </tr>\n') + + def image_tag(self, image_path, height, width): + return """<img src="{}" alt="IMAGE NOT FOUND!" height={} width={}>""".format(image_path,height,width) + + def close_file(self): + self.html_file.write('</table>\n</body>\n</html>') + self.html_file.close() + + + + +if __name__=="__main__": + + html_writer = html_atr_table_writer('/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier_v_1/trial.html') + col_dict={0: 'sam', 1: html_writer.image_tag('something.png',25,25)} + html_writer.add_element(col_dict) + html_writer.close_file() + diff --git a/classifiers/attribute_classifiers/atr_data_io_helper.pyc b/classifiers/attribute_classifiers/atr_data_io_helper.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b958afac0eda1f2f3ff3b1e2e2dc36c210e170a5 GIT binary patch literal 4641 zcmcgv-E$*H5${>qvLwqNXUkvV5_G;wAyQ!Zptu7fn>h9*A%KtUg_x)+*;;E?(klBQ z?~bvZ+kK#Xyzl`35`Pb`RPl0Ez^{8Gn*?~^A(nP#dwOT5dw%^fy`{e`)M|hK>nA;$ z{AO`~2Q8~BA`h3Mkf?h@E9$B#-=eNHrAsuLq3-N-Sf<e&bt}_hmAbVlJx|@b&=!RY z6fIJBk(5^=zGldqrI#h@F3Hev%EDRxH_n`J?vV!p=&SUS=$h!$?h5&<<kjd4i_(hn zSIC>EYpC3y0GdMV0zLK{G+HBXkvs@lqH73QrpMiN;s+b#uk#y<2b;X8iUow{tregW z{1d<97N8+LgDwiPUI3IRD52|B&~2g3qaC7Uf3z@>D3>L!E&B616&2EoMi%827F1N} z1PgrhTP+BBP_sn&3^j6KIYVoC+1%!6Favc|Q6rv-hc`0;*}#%#^K`jv=`!V5LRQ&; z%%jg<siIDLjz$$4RVi1bYn0bnx>qKiQY?_3rx7MC&}fnJ1%3tvAiNJTa}4ygrHVz7 zy+rvUjbM)|ZB3{eZLtbi5Gof-G{By#kbIwtWy&#i9#iXLmRCIf+@iBrRNSLn%(n&8 zY!$Fpm5Nom{6Jw3ps|?f|3noHy4?91_jOgQQQn{&W`rh<Vx4j@FoHIXVuQ8J!NX-{ zZ2SSBgN@5fVg|6s1)^Vq3Y)S)sAV=&6^;L)@=f-fwH!VSSnCY#W+PvvjhnasljzS- zC+^VnRf{(N@7J-BeE#bqbHId+HK+mS#fBRjn6|M6J9VHD(}3{oM4g-+p#7dsAQr3O z`SaE?eDWCHN0e#~^la#k{h$g;{5P0k9(OyV@c=deKh}OahVKx?zMJ})4nR^M5y0*q zMDD<M0#9c5lDHoXV0`@klDreh3J3v`;E{+)#(pfsC`$|xCa$M>Qal!={AjGNg!<A? zlT4`8jR(F=0)rJCyLxDrNuVMJU#3$VH_}Ecs2Y6f%gfO?h$g#<QXj&2>u#F5S2l~4 zWnwRgS{%BF03z1S=?9@7yOFO2AK9iCqzD~^z3;}3>>OdkqTN(GLG1aL^7z!%y`dxW zc{ZdwLq8Y{^<=4ma{<7HqG@mF#<3q}^5{nH?Z|VrD~q~1H6pW{Vd$Rnn{10fn8p~Q zH1E-9$SU;GWbBBwW%iluGV?>fr+v?X28nEBkOrP(j3_4WCFg8^QB<C#$;IuUmxQ9H z**-sFq|At)F_J8J3=yJA3k;GjI|0=mCXwIPZXCH+?Zd?L+wc4FCm$V5Z@tjXvY;RM zY1VeNPJ`2P?K@xPrVaz-1PKrh$9}pqzS3)0o6(AbIB<+D*vI^ok;uM>j;s~6X1%5E zTd%2atMB0dUA1VH)jIx*`2R?)skg0;+EwdTkA0sv#!Zs?5eSpiIh9H-hs$%uM1GO9 zYf|TRg^D?u048#%qAfH|w-rwQ&s*$$#{oKOu}|B(+_922dzJVlds>Sn+Z)_%a`zf{ z-{S5*cdv8z20HO;*{{u>d`D<=PPvT-gEx!LIQU7u1xN)#-WNB4@J}EB;b@B48qi~` zzygM-0u-_>&N#l6X;5Nxf!gJVoS%UuY7L${%Y0l%pY8HN3Q+m*!`Xv_2<w1TLVo;; zP3DjvD|Gft<V9Fv4yV<`7@R|nO&a8YYO&eH7y@fQOp?^DGmqU`vYnghtFcKj$U%=D z+dt&@#C|u%lCX@=qe<xgSF=^{TI60j0YXhIXFueH(7yR#5qaN-q>T4QYE?B65j$2* zNf-@v6gnJB+C2?@=OSetdpwhWjMKYl5>q~^M5zp4g21_|xG3wQxUz(RO-n+|lKouV zy~pm)jgdd(WqyDTrh@?;KrF;)S~)QjCKU%6$*Qos@(5wX@0(NYB6|wD^Wf;&;m4o; z`e~EZ->-Ehx6{PC!brB#JWWz$guU%Y+fA4;^!K(UC;#-(?(Uo0`^}oDD7kKW4%qAw zlHvpvW{YbGC=$yOvB;`@`)y3U8v}}p4i29{JI;?Gb`p10q%c94(c!~D5dItlV~7j3 zlGss~PY5R~$CZpFoE}nPz%~;t4mh@qMO!A~)Y(Hs(9KE`pvD@U$&NElSxx*pI^C~j zoFmN0+ka7X+S&t?14Wo^p3`C-?GAT)+)0$N?l^=c+WJWvxmt{E?}INJ$1&r0K~GBs z@P&VMN0@VZc0JF*0UAL)RygA*Cbg`#)C1LH5T;FlUj+kr19$@?auPl~CPzTVcI%)_ zXl7||Z~H~D-E>2}xBc^99lZavdHn3-=6lb6bA0q*yJ^aZy%$AOycNX0e2mx!Siu}3 zf(>J9iUBE<I2BJS5QbUay(6gD5sW6<9S9U0@}zp4Jh<^t_zytbdX(|a$(X-xCc)0r zk25|8;YOU45Yg>PWVa>Ijf>oPlEi2)Ofugf#rWLO-q{KwS!&7mFp}b$y`;q^nZz#F zYmOr)h2untcaBW!*gwMr`-r;_xcd;@t=d9;lSP#M6UcjlCVpL3)!AyfTAC;vxk2nW z*}HiBKWn(Bmt)8Fv-41A?E@Sjhd0HX^E>AeLOyWA9k!P#-rY#*pp3&lLs_VYK|I(Q z$Af@N8Ts5WyXT=@L1*5(%IrlR+(U=-Fb)$P2B)Y3nUcIQAxoGr4z-qY<#wXDru`-! z+8m<dEI6*mSLWzv3u-^kk`x8Ng?xbMzD&Y*KPKVvq-^7}-yGhIoc#-QQc}xVVa`os uNJh=r^^Q#{#_JgW+ULm2n&|3^|4aDuec4)hV`X-wwz^)o>gu6dQ2zqHe(m)D literal 0 HcmV?d00001 diff --git a/classifiers/attribute_classifiers/eval_atr_classifier.py b/classifiers/attribute_classifiers/eval_atr_classifier.py new file mode 100644 index 0000000..c3222b1 --- /dev/null +++ b/classifiers/attribute_classifiers/eval_atr_classifier.py @@ -0,0 +1,66 @@ +import sys +import os +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +import numpy as np +from scipy import misc +import tensorflow as tf +import atr_data_io_helper as atr_data_loader +import tf_graph_creation_helper as graph_creator +import plot_helper as plotter + +def eval(eval_params): + sess=tf.InteractiveSession() + + x, y, keep_prob = graph_creator.placeholder_inputs() + _ = graph_creator.obj_comp_graph(x, 1.0) + obj_feat = tf.get_collection('obj_feat', scope='obj/conv2') + y_pred = graph_creator.atr_comp_graph(x, keep_prob, obj_feat[0]) + accuracy = graph_creator.evaluation(y, y_pred) + + saver = tf.train.Saver() + saver.restore(sess, eval_params['model_name'] + '-' + str(eval_params['global_step'])) + + mean_image = np.load(os.path.join(eval_params['out_dir'], 'mean_image.npy')) + test_json_filename = eval_params['test_json'] + image_dir = eval_params['image_dir'] + html_dir = eval_params['html_dir'] + if not os.path.exists(html_dir): + os.mkdir(html_dir) + html_writer = atr_data_loader.html_atr_table_writer(os.path.join(html_dir, 'index.html')) + col_dict = { + 0: 'Grount Truth', + 1: 'Prediction', + 2: 'Image'} + html_writer.add_element(col_dict) + color_dict = { + 0: 'red', # blanks are treated as red + 1: 'green', + 2: 'blue'} + + batch_size = 100 + correct = 0 + for i in range(50): + test_batch = atr_data_loader.atr_mini_batch_loader(test_json_filename, image_dir, mean_image, 10000+i*batch_size, batch_size, 75, 75) + feed_dict_test={x: test_batch[0], y: test_batch[1], keep_prob: 1.0} + result = sess.run([accuracy, y_pred], feed_dict=feed_dict_test) + correct = correct + result[0]*batch_size + print(correct) + + for row in range(batch_size*9): + gt_id = np.argmax(test_batch[1][row,:]) + pred_id = np.argmax(result[1][row, :]) + if not gt_id==pred_id: + img_filename = os.path.join(html_dir,'{}_{}.png'.format(i,row)) + misc.imsave(img_filename, test_batch[0][row,:,:,:] + mean_image) + col_dict = { + 0: color_dict[gt_id], + 1: color_dict[pred_id], + 2: html_writer.image_tag('{}_{}.png'.format(i,row), 25, 25)} + html_writer.add_element(col_dict) + + html_writer.close_file() + print('Test Accuracy: {}'.format(correct / 5000)) + + sess.close() + tf.reset_default_graph() diff --git a/classifiers/attribute_classifiers/eval_atr_classifier.pyc b/classifiers/attribute_classifiers/eval_atr_classifier.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0be14f65263aa832ead7df9c9f5988c223c072a2 GIT binary patch literal 2804 zcmcgtTW=dx5T3J+oy2i&c2dVpFDVfa9!R7<<e@@9Riv~+svxaeWC5%7?y+;+y><6E zP3Z1Zio_f5`~rR!&paXV1K^uk+bxJ^y!P3dIdePT%pBMLTy5_E_WLuRrXL6I&+*HD z#uDJGC?x97X+?cClRGr7QGa2^)@j_J{^E>XqW-eb4uwsMR;a&1YU;5?{neS?8uiy_ z^ak}eXLOtT9SUmn2I03jk0%EU$jSVL&*K9H(lQ^J-1nK=#jlCq1N`!5HEa`&6<raP zimc=c`r8_n4wYCJM7ThO!_KcDtcyvV0!V5UEYOuwWkb{&tX4p|K|!5*SFm4{jYTSI zR4j`9lGrcNcv+$@Q?X3rCKZYb;MbsDut;6IBHR+3DC#1tvMm~~l3k<mI*m7EcNw}( z_6eO59$>ygy|OZGfoh3t71Z2vlZsW=&zp*UwWq9#fkcPKTkk{w_G>fmYh-X%RG?1B z!3G7Jc(p0$&~dp=1srWtvCe)K>!A>=w_WL7tMp*ir4og!Q`x4AUzNqi4%>Y3-7l)_ zNc=7pNFRAODd<trp<<J^x>$ni{JmPFa*I<bBt#EzA2Yd0Jpg=+3IYA&63HkU?@+PC zppQ_fT`IN_N|o2Bz?l&l^%3951(K787adNBBi*1aNebFsV%U+$G)7n2j_5~+hK|&x z+?lDsC#UfrDzHZ`J$m6#-gDlT5CGq#@pU<7SCB?~T_?Ls;~P{!K0V|+E_pH}{G1pr zAZWAd5e5dWv4g$<{!I2jqE=j?oBxg$+#-C@r2kdMmpupH=P%!==WGl;!d%{#bW!*{ z+6G1^4;km%vd43@S7{N3@DAqd@dyh(e|%<_@Ja?_cc?vU>tJ#}Np-$liI4mwKEG?g zkmccwl0b)U>_yu0%zlh%AR2|qzzf~n>eMcRPbSt4OqQ>)f%an8MBYdr#_6SHw^rxY z9p_0bob0i>WO8PsP;5Dvi35Fc$fCRj_9ROtvE6%`P3)N@^Hrt;<6D!&My?ai{t<^V zd<$A!Ap{LM)F{(BmiuN9PP9Rd$lT>+<3lW8vO1?W)H*<T_+(B%cwM@$%fmDtnI4Na zu{^~A_U`-sB=h{sdwZ`-^HYaNfW<upS3}#z>rreqG^{z-PjsFmO}h&5DD%=Y*Ux}= zl8Is(dcHnO!a!%PiPMSA(TP=6{3J?Uu@KQnTh~v*P|K-B<k<|OxfbK5&%JQsF)ng; zn|Y?9o_OavlM-b(LXv60D7TrU9j5|Nn0SHkNiKZq0qzC39Vc+pVE-2;M~)14^nxq5 z4vO;#JnOk^k0{%aNi})HdV^5AuQCHfq-DGyaCN976h-`GUW~J`4XZ~cHtxW){@Ki% z)DqR_@Lo2GybB?RNfvqb6hac-L^+Q5PG4DXB(36yNv_?Y3AKb17Pl%lNDfQ`J@h7_ zt&l$8e<I0}fpk+Z^P&oh9A%WgzmT<R=NDS1Zki<nk-4YRik#;qq-Z}LFmKu{5}Z@a zs%e@JhxEaGFeFhPvQsubm7(E56FjFV@(x--y0toD5+*1CnpfBgzTEtvrQm*&W#}5o z(TIXX6D{eZdF#8pA!q^{fKts-@=9_aS=R*Okk7_6Po{}R^OO(3lPvi+ppMRxNFP}* zj=amG2e>_sPIUb2>-)2pA9@&#q0w1>gbOG$gNfDdyHY8FbopHM5H|^TByVy2{Cg~< zTF$m=t8LX&ZD(D5;Jl}r^0d{4x}|Dr5l>UC&M2p$8mgl{Qu}Hbc6Fx<YG1X~9d!>n zD|p^lJ5Cecn(Baivl^Gn{PY$xd=;Sk(lD{186YR71h^^5O`8J1jejc2jVEZnrJ@uV ziY%T)>E-DS4p8031}wPh{<S*JlWZ6!uTI&kw9S0-OyZvDFxA=VKJV@DGOT*Kx^HkJ z%ry9dk{;)zr^2ePd<eLB2!*$`{2vOsgn*ym|78;DPkDy&Ei5gRV;xTmC1e_p^Ec}` BcfJ4s literal 0 HcmV?d00001 diff --git a/classifiers/attribute_classifiers/train_atr_classifier.py b/classifiers/attribute_classifiers/train_atr_classifier.py index dd490ad..5145bae 100644 --- a/classifiers/attribute_classifiers/train_atr_classifier.py +++ b/classifiers/attribute_classifiers/train_atr_classifier.py @@ -5,36 +5,33 @@ import matplotlib.image as mpimg import numpy as np import tensorflow as tf import atr_data_io_helper as atr_data_loader +import tf_graph_creation_helper as graph_creator +import plot_helper as plotter -def train(): - # Start session +def train(train_params): sess = tf.InteractiveSession() - x, y, keep_prob = placeholder_inputs() - y_pred = comp_graph_v_2(x, y, keep_prob) - - # Specify loss - cross_entropy = -tf.reduce_sum(y*tf.log(y_pred)) - - # Specify training method - train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy) - - # Evaluator - accuracy = evaluation(y, y_pred) + x, y, keep_prob = graph_creator.placeholder_inputs() + _ = graph_creator.obj_comp_graph(x, 1.0) + obj_feat = tf.get_collection('obj_feat', scope='obj/conv2') + y_pred = graph_creator.atr_comp_graph(x, keep_prob, obj_feat[0]) + cross_entropy = graph_creator.loss(y, y_pred) + vars_to_opt = tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr') + all_vars = tf.get_collection(tf.GraphKeys.VARIABLES) + print('Variables that are being optimized: ' + ' '.join([var.name for var in vars_to_opt])) + print('All variables: ' + ' '.join([var.name for var in all_vars])) + train_step = tf.train.AdamOptimizer(train_params['adam_lr']).minimize(cross_entropy, var_list=vars_to_opt) + accuracy = graph_creator.evaluation(y, y_pred) - # Merge summaries and write them to ~/Code/Tensorflow_Exp/logDir - merged = tf.merge_all_summaries() - - # Output dir - outdir = '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier_v_1/' + outdir = train_params['out_dir'] if not os.path.exists(outdir): os.mkdir(outdir) # Training Data img_width = 75 img_height = 75 - train_json_filename = '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json' - image_dir = '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images' + train_json_filename = train_params['train_json'] + image_dir = train_params['image_dir'] mean_image = atr_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width) np.save(os.path.join(outdir, 'mean_image.npy'), mean_image) @@ -44,12 +41,17 @@ def train(): # Session Saver saver = tf.train.Saver() - + # Start Training sess.run(tf.initialize_all_variables()) - batch_size = 100 - max_epoch = 10 - max_iter = 95 + + # Restore obj network parameters + saver.restore(sess, train_params['obj_model_name'] + '-' + str(train_params['obj_global_step'])) + + + batch_size = 10 + max_epoch = 2 + max_iter = 950 val_acc_array_iter = np.empty([max_iter*max_epoch]) val_acc_array_epoch = np.zeros([max_epoch]) train_acc_array_epoch = np.zeros([max_epoch]) @@ -57,24 +59,20 @@ def train(): for i in range(max_iter): train_batch = atr_data_loader.atr_mini_batch_loader(train_json_filename, image_dir, mean_image, 1+i*batch_size, batch_size, img_height, img_width) feed_dict_train={x: train_batch[0], y: train_batch[1], keep_prob: 0.5} - _, current_train_batch_acc = sess.run([train_step, accuracy], feed_dict=feed_dict_train) - train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] + current_train_batch_acc val_acc_array_iter[i+epoch*max_iter] = accuracy.eval(feed_dict_val) + plotter.plot_accuracy(np.arange(0,i+1+epoch*max_iter)+1, val_acc_array_iter[0:i+1+epoch*max_iter], xlim=[1, max_epoch*max_iter], ylim=[0, 1.], savePath=os.path.join(outdir,'valAcc_vs_iter.pdf')) print('Step: {} Val Accuracy: {}'.format(i+1+epoch*max_iter, val_acc_array_iter[i+epoch*max_iter])) - plot_accuracy(np.arange(0,i+1+epoch*max_iter)+1, val_acc_array_iter[0:i+1+epoch*max_iter], xlim=[1, max_epoch*max_iter], ylim=[0, 1.], savePath=os.path.join(outdir,'valAcc_vs_iter.pdf')) - - + train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] / max_iter val_acc_array_epoch[epoch] = val_acc_array_iter[i+epoch*max_iter] - - plot_accuracies(xdata=np.arange(0,epoch+1)+1, ydata_train=train_acc_array_epoch[0:epoch+1], ydata_val=val_acc_array_epoch[0:epoch+1], xlim=[1, max_epoch], ylim=[0, 1.], savePath=os.path.join(outdir,'acc_vs_epoch.pdf')) - - save_path = saver.save(sess, os.path.join(outdir,'obj_classifier_{}.ckpt'.format(epoch))) + plotter.plot_accuracies(xdata=np.arange(0,epoch+1)+1, ydata_train=train_acc_array_epoch[0:epoch+1], ydata_val=val_acc_array_epoch[0:epoch+1], xlim=[1, max_epoch], ylim=[0, 1.], savePath=os.path.join(outdir,'acc_vs_epoch.pdf')) + save_path = saver.save(sess, os.path.join(outdir,'atr_classifier'), global_step=epoch) sess.close() + tf.reset_default_graph() if __name__=='__main__': train() diff --git a/classifiers/attribute_classifiers/train_atr_classifier.pyc b/classifiers/attribute_classifiers/train_atr_classifier.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78a4e66e7a05fc26d316c57a52f11a19907ab03e GIT binary patch literal 3616 zcmcgtOK&5`5w0dBN}{O8QV)vKYL~KhH`*9p0Xb|CV-ZKT1KDuk#1fz&gzZ6#(<FyG zubpnmqBC<+{y+`^@*9$4fc%8~iCl6FkV^pnfaI$hQV%<TZ{ee-y1JfURrOTAS*ri( z;lIv$H2GBU{2BiC4_JKs6oo|H8(L9UP5BiXRjE5O-OkdeM%}sTcAmNmQ@T#w#VOsO z?$VTArtZp=UZw5>@+<T^xNDMGBfm<oE0q17Op~sOR4mb)9?T%J;6M0$+EXC+@n6ON z1N`mZaOj9eiZu#FPGSW8d6kL^6<BA4IYW7cle-3UR#ax$IehvvbgfiT6R{eL<q)os zKTGXv=+DW)9OYHY=R|*gqR)=!X@sm{r6yJuzI}A9sHl?<lptBB0a3m{AAYey#Uka4 zG+LxDEA-|sP^Q-LM?P%b$ko3q_xFh$dv4HZiE?06N1TdA%aqTP5336jNkhV3xf6CN z-Jo`{Om3A%4`|e+e3?z!GsT>>$yT6Wq5`$Tx+-3lq4z+j6;Mrb>onS+9IPh!tK@@# z;0j&Ws8}PnNu$=(OP$U7C@P>`<5jWp)<hXBAHVgf{u~wS6aegXDFN#7lF1EfqXI8c zLX_uau^~VMgSJ$DlX5_?EzBL}ASsn2h=+W@C~Z>PZ&Azb%5F<AYjG@on~E0s+f;0k zze7b^f`3GLoANDMAE0MYLzsrO$IO~=cD}=TNLQVT>#z(|KK!CtDR$}VSTPG>GV3R* zcqGu>%^onqU5nORSSBa}g~x)z<69I!PEkNwi~<Zx+IfTWB^scgkzNO`)aTOY#UAAl z-34}g)GEnWkQ{-(!4c)#;u~6RxPRr|qS3zO^@QCbEO=}|Uygy6Y`6c%*`6=5NS#GF z5(ymTV!Cs`752CZt}9jL!{5_yySX1z8=}V_z4>$H*V9bev@y-(Ui9sJhf|d#o{@hy zk6j6T62Je1LE-dD_MVRrcJ48zxqMF^&m<43NvEN7xPQvIDm(rMeC!{vpX-V$-l8{u zL>4F*4sgHm5FZGId-wF<IqtJhaX;|q-|xA3Jd*Q~*Eia^S!`@CNsV17*@wL(zWh;u z4Y?{7?K1lW&oi9``aCqY<A&OGbY?o|CWr@}By~X)d~N)9JC1MBj$Ofe6o#G4sR~Z! z*p#m$FU%wWoOxkjog5?=&hvvT_yj_(&P%3)*c;g-b{tj^>46c6e<Gv+Ns(Q~5*ZzP za`+-n$HDGX9Dais0*=G61m_tgz6rfpM@Hg%$8mgoFbI=#9eUQ8G+?g*zp{eA!o>uv zf4nE|`^Nb2-g6G`1iA`?NQf~L45rnW=2zMcZ4;YI9Uk?1-lg>dXR;Tm-?!UfpTfzz zo!=In&Y2E7khsWnZ_FGi!IiI_c8d@kGvx^Y;=J6$s^wRcE=`*>=?&SoB#9Dvy--^l z^aGQLl{=|P`SJDwB7MgZejFn(-E+aEIW<rM$V*@kGMx^+UWTqnG7(He-7~`^^iAdk zaeCqG4W_*$N<GmK*1$NgmxQ5#RhVPsC8%2|MlMXOm7I>*)hA|L?#_-*K0bQ?7e7B0 zR7Xh=3n2yL_AIgB5RcG6znrvN*+x+iOY=$B7@%~aB?3-2v0^7hWn2w=bA^7h0(bNp z?IwM|VG815qLHWAxos<fEl$M_SJ`dA4y&+V?_9gy(3>c7q^A(cWNNaD*lmFWF&F3% z<KgKr^lm0X+CDQDP#aOOE|bVil)AA{Uz-dWu-Qz<0GJWv3`I2!6Nf5F;h6T&7}8Oa zsV+4*>3)($7;|p<yK(_sv*{jcV%!=i1cM^_rmrtTSC;l*AAHx~MdmSu-KYfBqH$!J zu4ElcSzjBIdTExN3(q@&B1t@BaE&j)_K=kIOzg5G9ZRbM6YIIe!#yI?G{MOVEI7P# zA@HP@GIz7Z?ZMS7FN%TpRp6uGGSAV#8=7D+bdurirQ-F2&~Sg9@c5qmkmF=^pj#GI zLcJM#lr0prxCP31%F<iZRTLwA<;hG+biCu?mmy-CJheK@^tfa*-1pg~E;m;h<gG$^ z!lf!HzS#o7t(25Nq`bG`%86@15<QUt>+(98^eb#j%$$tJ!*KvXlEwZ8)Zs9R%%Rh9 zq{oLJ;O;y;Hu2eKN7KjcO=AzySy^y?;mn=;#~w~{xOXePNXI98u+CSU{SphQMrA`a z)rM+g-NEw!->0fssbjxdt*M%-<NHWGuI{P*>PmI3`b@YrwF!x)TEq9b+E;s`pQ?8% zO`KTucD1gg4xShBYs5QPhr4ph#`feXZysR5T<}PQ!8wAHf#QLgko@XPLx;ESJwbWd zNH5${Q5r-8;l&pb(tH~#Wz(A2CRsmBzB-Z8^5uJT-vkMIGfYi(@|@-QvT*%!?#fFT z7wJ@jODQ#%5uT|17)*h5k}LN!as{_6=KrCh;`tTcuNR^DDR-vb#L`d=wSq4{<y%>J HzoGsM4fITj literal 0 HcmV?d00001 diff --git a/classifiers/object_classifiers/eval_obj_classifier.py b/classifiers/object_classifiers/eval_obj_classifier.py index 6c78aa4..e9320c4 100644 --- a/classifiers/object_classifiers/eval_obj_classifier.py +++ b/classifiers/object_classifiers/eval_obj_classifier.py @@ -5,7 +5,7 @@ import matplotlib.image as mpimg import numpy as np from scipy import misc import tensorflow as tf -import object_classifiers.obj_data_io_helper as shape_data_loader +import obj_data_io_helper as shape_data_loader import tf_graph_creation_helper as graph_creator import plot_helper as plotter @@ -25,32 +25,42 @@ def eval(eval_params): if not os.path.exists(html_dir): os.mkdir(html_dir) html_writer = shape_data_loader.html_obj_table_writer(os.path.join(html_dir, 'index.html')) - col_dict = {0: 'Grount Truth', - 1: 'Prediction', - 2: 'Image'} + col_dict = { + 0: 'Grount Truth', + 1: 'Prediction', + 2: 'Image'} html_writer.add_element(col_dict) - shape_dict = {0: 'blank', - 1: 'rectangle', - 2: 'triangle', - 3: 'circle'} + shape_dict = { + 0: 'blank', + 1: 'rectangle', + 2: 'triangle', + 3: 'circle'} + batch_size = 100 correct = 0 - for i in range(1): + for i in range(50): test_batch = shape_data_loader.obj_mini_batch_loader(test_json_filename, image_dir, mean_image, 10000 + i * batch_size, batch_size, 75, 75) + #print(test_batch[0].dtype) +# print([np.amax(test_batch[0][0,:, :, :]), np.amax(mean_image)]) feed_dict_test = {x: test_batch[0], y: test_batch[1], keep_prob: 1.0} result = sess.run([accuracy, y_pred], feed_dict=feed_dict_test) correct = correct + result[0] * batch_size print correct + for row in range(batch_size * 9): gt_id = np.argmax(test_batch[1][row, :]) pred_id = np.argmax(result[1][row, :]) if not gt_id == pred_id: img_filename = os.path.join(html_dir, '{}_{}.png'.format(i, row)) - misc.imsave(img_filename, test_batch[0][row, :, :, :]) - col_dict = {0: shape_dict[gt_id], - 1: shape_dict[pred_id], - 2: html_writer.image_tag('{}_{}.png'.format(i, row), 25, 25)} + misc.imsave(img_filename, test_batch[0][row, :, :, :] + mean_image) + col_dict = { + 0: shape_dict[gt_id], + 1: shape_dict[pred_id], + 2: html_writer.image_tag('{}_{}.png'.format(i, row), 25, 25)} html_writer.add_element(col_dict) html_writer.close_file() print 'Test Accuracy: {}'.format(correct / 5000) + + sess.close() + tf.reset_default_graph() diff --git a/classifiers/object_classifiers/eval_obj_classifier.pyc b/classifiers/object_classifiers/eval_obj_classifier.pyc index c4e6f97b11143b781aef4f2a8c5d37ac1afe9cfa..48898ca0a27747f6356b9311b4ebe8a2e4f3f745 100644 GIT binary patch delta 426 zcmXw#!7D^j6vn@E-^?>JM$EkTW`@BqLy<xfWrarUNqUiR$LKXhJi9weV`ss1v$3&K z%1ZMGD1U>Evapnm1<BHx!9C~P`|EVR?{@FIed?r*kCw>G*W2<I{+8f#n(N@Kn)Y&T zD<mKZr8c;+e~!boAPt-Vu3fzyrIEYmB^tjW--b+}LddUOQaJ`c2G4;CLq(uMkV&W* zR4eQxZ!*QHrReQwhYDb#Y(Wvc1e6ICL{UI?=qr)6th#SqkcQpxCen{$lREXJF&ALp zq*g|s)d?@v@OMyQZ5OAm8N)-}jeZ+s1`f;T7wWfkp#qQ(c-;-+S1u0p=x^hb#`J4o zg)+67;4x+WApgZyY4@PuTTJ@}hILuk-Z(7z>&5+zy`9>Fc}=vccS4z2gPJ%Z)Iz_+ zEKD-VqCq1^X^K-&Bq_%crCyrisb-FTN{I-?xdt44AG**-VVCmyDSU25c`zz|0Mt`c AP5=M^ delta 437 zcmXv~K}!Nr5dP+EyXmHFx~^+&7G{cFq98g)6m_nMP(sw*%0(#B<9Z0Yb*e`}(6OLv z*dOWGPw>>Cu0^v_Z)V<ne9ZUp=3QAIDM$H?>;7P{-@xAy+}F6;H`k6|^bM3J-~uk8 zl|VW9k4=<Qa1~qvW$XHRiR_p6Jz9CB4%2b;2pflY)J>q1fNvp)BG3^;;HD8I5Ew|= ze7QEKtcl!I2EhncyCwpH&odeV6~_YZq<j!_rV;vvf+?hi-rtaKqE3!nQ8r~)nKvEy z?7`L+dMumfl&{L7ogMPL!wf7|Zf)~2yQNtfANi7+x<}c5MZKW@kM=|dJK-NtOjwkp zq%f#RbEFR=NfR`~Xp}2H>><)POj@8kIjm60W|XHG<!F`0X^aY#5jvI9QjOj&6xjQ; z?t6_+wb81z+s&h<civuPs_ND{wQBRUdg8UtymNUUtwh&(x)eS~E=N3B*Ylde#xd~= D|BF&* diff --git a/classifiers/object_classifiers/obj_data_io_helper.py b/classifiers/object_classifiers/obj_data_io_helper.py index 5a68329..1e9b435 100644 --- a/classifiers/object_classifiers/obj_data_io_helper.py +++ b/classifiers/object_classifiers/obj_data_io_helper.py @@ -1,4 +1,3 @@ -#Embedded file name: /home/tanmay/Code/GenVQA/GenVQA/classifiers/object_classifiers/obj_data_io_helper.py import json import sys import os @@ -9,19 +8,20 @@ import tensorflow as tf from scipy import misc def obj_mini_batch_loader(json_filename, image_dir, mean_image, start_index, batch_size, img_height = 100, img_width = 100, channels = 3): + with open(json_filename, 'r') as json_file: json_data = json.load(json_file) - obj_images = np.empty(shape=[9 * batch_size, - img_height / 3, - img_width / 3, - channels]) + + obj_images = np.empty(shape=[9 * batch_size, img_height / 3, img_width / 3, channels]) obj_labels = np.zeros(shape=[9 * batch_size, 4]) + for i in range(start_index, start_index + batch_size): image_name = os.path.join(image_dir, str(i) + '.jpg') image = misc.imresize(mpimg.imread(image_name), (img_height, img_width), interp='nearest') crop_shape = np.array([image.shape[0], image.shape[1]]) / 3 selected_anno = [ q for q in json_data if q['image_id'] == i ] grid_config = selected_anno[0]['config'] + counter = 0 for grid_row in range(0, 3): for grid_col in range(0, 3): @@ -31,7 +31,7 @@ def obj_mini_batch_loader(json_filename, image_dir, mean_image, start_index, bat if np.ndim(mean_image) == 0: obj_images[9 * (i - start_index) + counter, :, :, :] = cropped_image / 254.0 else: - obj_images[9 * (i - start_index) + counter, :, :, :] = (cropped_image - mean_image) / 254 + obj_images[9 * (i - start_index) + counter, :, :, :] = (cropped_image / 254.0) - mean_image obj_labels[9 * (i - start_index) + counter, grid_config[6 * grid_row + 2 * grid_col]] = 1 counter = counter + 1 @@ -51,7 +51,6 @@ def mean_image(json_filename, image_dir, num_images, batch_size, img_height = 10 mean_image = mean_image + mean_image_batch(json_filename, image_dir, 1 + i * batch_size, batch_size, img_height, img_width, channels) mean_image = mean_image / max_iter - tmp_mean_image = mean_image * 254 return mean_image diff --git a/classifiers/object_classifiers/obj_data_io_helper.pyc b/classifiers/object_classifiers/obj_data_io_helper.pyc index 46ba9ab54418e0b517e20fe9253b2dbea0b4aa99..9bdb6c7015cc3a0da7906325045969549a720070 100644 GIT binary patch delta 443 zcmXYsIZH!95XX1+>ATM)1{07R1Wm+)6bfRDCqyDxJnN7mibf+s-UFN1SlB3{jD?Mz zt%AH>{0M#nTaQw(cGk!+znS^(e|Fi2!dXF)zAfGR`Fv6ulDciNp&n<$I4PC!PD-gZ z>m^uX+rejP0$lu(j-uwkJs*KFD1psoNK)7oGcW}6yaE%D<WW{BL%zfq3uFR=aWb_O z>h%9yr5apfpkCQqMTv&ZwqP4Dt5rb9hE~HHua)*wq)rR7Nnj9%EM#Dc>k!TA3jp(Y z?H%Bop;Gs5yvP+Y`3v}H=AeqN=DD{-LZA{DxR8Pz$BxxCmT4{y5d12BShrBaQ+q$% zUnVsiB(yq5!A=qSrfY8(4;>ro*mI`bMIo*UtP9lf!aGR#bJUCcdrs!spr%A~ZVASP z=x-20gRmDKcACfCcDU2+udj&XzuUw`C!=|qO%UNwi0blKCML)9cv4I{{lkJ^pFm5} delta 482 zcmX9)IZFdU6rQ(}>}EH6<uG1|SHP%22*F095f8+(h?RvbV%B6eiIQEgiKR9+%4=*a z#mYj`tF?vTPq4PM7vCfd@8iw)eeamLPCus={%dIYkM9@dfXgk3^_5MW=6QVJ{Sk|` zk}O%a@s(#GhClo)^v4)$?}r|xA!!08!IWf3(J?8Tg&vrf3Xp|Vs6D9UU@S{!UnXEM zMW!Q#vM7Z{u|nbBV<U1qROT#ssj*gs?PWql?|Pij0|9@is|ASShPs{zOk#rixl?x! z8;(~$thjiqR(o?q`$A)ofDssm8A$_;w+(FppnzXW#fXsTI9|gG9H)qH(RX=)n9?F% z>7!7@C;cY2N`x<b3~cC!QNR_W8C|A(XQ@&V-;I0N#C`K*pj%8*5GFXPo%wD?tV<@E zW{LN8NaxtAxjVRL&DcvqUK6k(U<>c8T4IjUMT*W`8Pu>Uw{3e8ri4)1CJ4lTRfBrX nxoCQU+w4}0&+9f5eNOQ+7Kai1jwP+JcsOn(Qo5wWI2iB`wWeEq diff --git a/classifiers/object_classifiers/train_obj_classifier.py b/classifiers/object_classifiers/train_obj_classifier.py index 292e1ca..45d5830 100644 --- a/classifiers/object_classifiers/train_obj_classifier.py +++ b/classifiers/object_classifiers/train_obj_classifier.py @@ -21,21 +21,26 @@ def train(train_params): if not os.path.exists(outdir): os.mkdir(outdir) + # Training Data img_width = 75 img_height = 75 train_json_filename = train_params['train_json'] image_dir = train_params['image_dir'] mean_image = shape_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width) np.save(os.path.join(outdir, 'mean_image.npy'), mean_image) + + # Val Data val_batch = shape_data_loader.obj_mini_batch_loader(train_json_filename, image_dir, mean_image, 9501, 499, img_height, img_width) feed_dict_val = {x: val_batch[0], y: val_batch[1], keep_prob: 1.0} + # Session Saver saver = tf.train.Saver() + # Start Training sess.run(tf.initialize_all_variables()) batch_size = 10 max_epoch = 2 - max_iter = 1 + max_iter = 950 val_acc_array_iter = np.empty([max_iter * max_epoch]) val_acc_array_epoch = np.zeros([max_epoch]) train_acc_array_epoch = np.zeros([max_epoch]) diff --git a/classifiers/object_classifiers/train_obj_classifier.pyc b/classifiers/object_classifiers/train_obj_classifier.pyc index d803fd2f3ff49fa7edbd7c3cee464cb231ef703c..eac87e3278c2ff50ee06780b0e6328693ac40b63 100644 GIT binary patch delta 273 zcmZ1}u}*@W`7<w<`L|aa*`1ge1txnjEte8xs9|9+tCeP`VPwb>WvF3isNrBpXJSYZ zn#|2yX(r5&!oyG_%upl3AOaE<VMyU-NZ|#m5M?OgV#wlV$P!|x;bce=V`ydqB4)-E z@yWZGqf#UoN(30PWI-BgBtXW91LZ+fiX=mgIFhhpiWEbQ7>WSQS^`!|Pp)8*W0wJ0 zoFY4UE{i&&+~lJyP0X3wm^Zt!zGD$)W@BVy<N`uDMtNpMMrCGx=FM(g4;Y!vco+ei C9x}@S delta 269 zcmZ1{u~LGa`7<w<z{{5#*`1gec_w=?Ete8ts9|9+tCeP`VPwb>WvF3isNrBpXJSYZ zoXpK!X(q&w!oyG_%upl3AOaE<W=P>?NZ|#m5Me0cV#wlV$P!|x;bce=WoTvsB4)-E zvB|rbqf*2fN(30PWI-BgBtXW91LZ+fiUdQAIFhhpiX=mg7>WSQS^`!|O|D>(W0wY5 yoFX%ME{i&&?Bt^?O`CmK->?WXu`#kSasi<nqdcP`qcRge(`Ik32aL>yJd6N)k1<>T diff --git a/classifiers/train_classifiers.py b/classifiers/train_classifiers.py index 6367807..d2edabb 100644 --- a/classifiers/train_classifiers.py +++ b/classifiers/train_classifiers.py @@ -6,35 +6,59 @@ import numpy as np import tensorflow as tf import object_classifiers.train_obj_classifier as obj_trainer import object_classifiers.eval_obj_classifier as obj_evaluator +import attribute_classifiers.train_atr_classifier as atr_trainer +import attribute_classifiers.eval_atr_classifier as atr_evaluator workflow = { 'train_obj': False, 'eval_obj': False, + 'train_atr': True, + 'eval_atr': True, } -train_params = { +obj_classifier_train_params = { 'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier', 'adam_lr': 0.001, 'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json', 'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images', } -eval_params = { +obj_classifier_eval_params = { 'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier', 'model_name': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier/obj_classifier', 'global_step': 1, 'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json', 'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images', 'html_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier/html_dir', - 'create_graph': False, +} + +atr_classifier_train_params = { + 'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier', + 'adam_lr': 0.001, + 'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json', + 'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images', + 'obj_model_name': obj_classifier_eval_params['model_name'], + 'obj_global_step': 1, +} + +atr_classifier_eval_params = { + 'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier', + 'model_name': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/atr_classifier', + 'global_step': 1, + 'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json', + 'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images', + 'html_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/html_dir', } if __name__=='__main__': if workflow['train_obj']: - obj_trainer.train(train_params) + obj_trainer.train(obj_classifier_train_params) - obj_feat = tf.get_collection('obj_feat', scope='obj/conv2') - print(obj_feat) if workflow['eval_obj']: - obj_evaluator.eval(eval_params) + obj_evaluator.eval(obj_classifier_eval_params) + + if workflow['train_atr']: + atr_trainer.train(atr_classifier_train_params) + if workflow['eval_atr']: + atr_evaluator.eval(atr_classifier_eval_params) -- GitLab