From 92f02b56957d36d58b1c5b4d7544de055358ba4a Mon Sep 17 00:00:00 2001
From: tgupta6 <tgupta6@illinois.edu>
Date: Fri, 18 Mar 2016 13:17:45 -0500
Subject: [PATCH] optimized atr training and evaluation

---
 .../atr_data_io_helper.py                     |  32 +++++----
 .../atr_data_io_helper.pyc                    | Bin 4641 -> 4538 bytes
 .../eval_atr_classifier.py                    |  14 +++-
 .../eval_atr_classifier.pyc                   | Bin 2804 -> 3079 bytes
 .../train_atr_classifier.py                   |  65 +++++++++++++-----
 .../train_atr_classifier.pyc                  | Bin 3616 -> 4071 bytes
 .../train_obj_classifier.py                   |   7 +-
 .../train_obj_classifier.pyc                  | Bin 3118 -> 3261 bytes
 classifiers/tf_graph_creation_helper.pyc      | Bin 5656 -> 5656 bytes
 classifiers/train_classifiers.py              |   8 ++-
 10 files changed, 89 insertions(+), 37 deletions(-)

diff --git a/classifiers/attribute_classifiers/atr_data_io_helper.py b/classifiers/attribute_classifiers/atr_data_io_helper.py
index 0f707a9..572704d 100644
--- a/classifiers/attribute_classifiers/atr_data_io_helper.py
+++ b/classifiers/attribute_classifiers/atr_data_io_helper.py
@@ -6,47 +6,49 @@ import matplotlib.image as mpimg
 import numpy as np
 import tensorflow as tf
 from scipy import misc
+import time
 
-def atr_mini_batch_loader(json_filename, image_dir, mean_image, start_index, batch_size, img_height=100, img_width=100, channels=3):
-
-    with open(json_filename, 'r') as json_file: 
-        json_data = json.load(json_file)
+def atr_mini_batch_loader(json_data, image_dir, mean_image, start_index, batch_size, img_height=100, img_width=100, channels=3):
 
     atr_images = np.empty(shape=[9*batch_size, img_height/3, img_width/3, channels])
     atr_labels = np.zeros(shape=[9*batch_size, 4])
 
-    for i in range(start_index, start_index + batch_size):
+    for i in xrange(start_index, start_index + batch_size):
         image_name = os.path.join(image_dir, str(i) + '.jpg')
         image = misc.imresize(mpimg.imread(image_name),(img_height, img_width), interp='nearest')
+        #image = np.zeros(shape=[img_height, img_width, channels])
         crop_shape = np.array([image.shape[0], image.shape[1]])/3
-        selected_anno = [q for q in json_data if q['image_id']==i]
-        grid_config = selected_anno[0]['config']
-        
+        grid_config = json_data[i] 
         counter = 0;
-        for grid_row in range(0,3):
-            for grid_col in range(0,3):
+        for grid_row in xrange(0,3):
+            for grid_col in xrange(0,3):
                 start_row = grid_row*crop_shape[0]
                 start_col = grid_col*crop_shape[1]
                 cropped_image = image[start_row:start_row+crop_shape[0], start_col:start_col+crop_shape[1], :]
+
                 if np.ndim(mean_image)==0:
                     atr_images[9*(i-start_index)+counter,:,:,:] = cropped_image/254.0
                 else:
                     atr_images[9*(i-start_index)+counter,:,:,:] = (cropped_image / 254.0) - mean_image
-                atr_labels[9*(i-start_index)+counter, grid_config[6*grid_row+2*grid_col+1]] = 1
+
+                if grid_config[6*grid_row+2*grid_col]==0: # if the space is blank
+                    atr_labels[9*(i-start_index)+counter,3] = 1
+                else:
+                    atr_labels[9*(i-start_index)+counter, grid_config[6*grid_row+2*grid_col+1]] = 1
                 counter = counter + 1
 
     return (atr_images, atr_labels)
 
-def mean_image_batch(json_filename, image_dir, start_index, batch_size, img_height=100, img_width=100, channels=3):
-    batch = atr_mini_batch_loader(json_filename, image_dir, np.empty([]), start_index, batch_size, img_height, img_width, channels)
+def mean_image_batch(json_data, image_dir, start_index, batch_size, img_height=100, img_width=100, channels=3):
+    batch = atr_mini_batch_loader(json_data, image_dir, np.empty([]), start_index, batch_size, img_height, img_width, channels)
     mean_image = np.mean(batch[0], 0)
     return mean_image
 
-def mean_image(json_filename, image_dir, num_images, batch_size, img_height=100, img_width=100, channels=3):
+def mean_image(json_data, image_dir, num_images, batch_size, img_height=100, img_width=100, channels=3):
     max_iter = np.floor(num_images/batch_size)
     mean_image = np.zeros([img_height/3, img_width/3, channels])
     for i in range(max_iter.astype(np.int16)):
-        mean_image = mean_image + mean_image_batch(json_filename, image_dir, 1+i*batch_size, batch_size, img_height, img_width, channels)
+        mean_image = mean_image + mean_image_batch(json_data, image_dir, 1+i*batch_size, batch_size, img_height, img_width, channels)
 
     mean_image = mean_image/max_iter
     return mean_image
diff --git a/classifiers/attribute_classifiers/atr_data_io_helper.pyc b/classifiers/attribute_classifiers/atr_data_io_helper.pyc
index b958afac0eda1f2f3ff3b1e2e2dc36c210e170a5..457ee8ea56d594f13e29953a5ac035b031255cd4 100644
GIT binary patch
delta 1518
zcmZux&2Jl35dXdPN9^_5>yLH(k&w`aHYP|kpw=iQh_opPsffcOkzx^9CcAOg*lWjL
zgoAcvgv%8Py&IgkfW!qM4oIAO=z$vt{sgXY;)K)#GtW_J)%v}iH*aR%%x~t+^S96a
zy0H@bC#yfW@z;AtIR7PR-J@?({c6RZ0ggc5M9YK|Q!@@Hfx9uZ@{dT8N>X@4vZN$M
zI4L+JPvb5y2i|X$;978We57G)3AY3%gSkMf3{Qh&z{#TNmeE^=lfy-a;+mLK=scRO
z3RqKxTj3*uH#NB?uFlE?Sf_8FfuCwX0Q3Z7cQk}?SYd*(mBZv>Swc5Xa72_+mK6A*
zr4vV2Q@S*!Yg!OPxR`NN<{EAtrnN{%5JxY8UJ{`|urx!cGa{-C+e8@zDfH3^bo4R^
zRf2?oq90Qz!mxz7CL%)xHxL@=Ww}zdb1sb<V<3lwIm#jg2dG6kHi}G4=Ynu?X=CIF
zw`PZ95#=%U9v2_p5m5nQ9^t~f6%`SZfOZtYib@<gBpJ<xMdb@B7l~;uj9B1M69B{<
z#|VR^K<5xH0H*jN>*F6psBX0!5-9SabJeO=Xv)`Om+I##4b`WpFU&&9i^`OTMfXnt
ztxAn@;|95)5z;XIxJy(hZ7Srhzs-p}seaNQnN@0hB&2>Ko@#g+69=b%<KVPa!~ww~
zH)QZ`BTFsxj{G;aztbfSxHsJMlC)BH0u7T$z??q{_@%gmQE;ZXPu=lw60qe|ejmT{
ztYFJHTRPsL!3`RuiS#BzzwLB_j@(Fe<#&m<rJnrrN`^F(Eh5mi#Hv;o>ta*e6uL-?
zWg0me*R-;DRom58#kwZfmWH=4leNtOYzCr24cw03;3#&1Ns)=ggu~m*Oe##OL=@-`
zJHgS3U6MbVp1hGNi!FJO`fbKemHN0p9FA?yNo~g+zw4?LJCopS<SL=(2QM}4=gFqQ
zFYy1!)clnE$Ah*v=ycu5Q!%8-WYC$my}%vYB|g3a$e+xU1aGeVoYXIgv`TF_a(%g-
z*%U7`*c`ntpxyR-FKD-~%ilBiGFSQdH6|~}uZ>mFl)o5Xir3|v+0QdOByOl$HRT`K
zZ^RpNnCsOao=}M#^@kI;!PjOtRY5*5P4SlOn=j^HVc{;5JtjA0RPdAbHt{=hEiWr|
z;uK20U3!^xnP(F_CGY2NmY?Ah+*B$E)njyK$K^_)v9C@9-oUlhRdNXzHDJd!|CI&Y
x7bn4>-~ME*O0>tC>r50H2J+eEi?lZc50eJ3d0}Y9^<tu^mzE995E~*Z{sR$^1`PlJ

delta 1636
zcmZuxO=w(I6#nj;zs$^=pG@BT<gY48>9c6DsHxIalcq?qTJQ;`4c0K7_hy_iGm~VV
zkidHcP2EUQ<St!^J6*f$s?d!~m*QGnxGIPR35egh6G92jefPX`zkBXE-#O>Ld9d*P
z;%wrtLN52@&ktHS{-yX{;T+lD%nl9`@cAU+r*JQUjgnc8k+ea|PazqDOv0alPnpF%
z;1IaIQ9&>TKZj2&>`thl0zZ#Kg$)N?3%(7%fTh4euL{2ipFv7EWRNmyOB*%FP8~r_
zjwo3)^m+Bl!8A$cIoE{jmlhBL5#TEvTKKw#cmkn9&q6dIC*v&Eec-)YVFEoS#J~we
zNjOnT60An)@hgK)lB`vngIIF&#F9nmSY|Jey@Rq9rV-JHaZ_i|Yj=)RY$ME|H-TOj
zkwTb5WJ`EI4Kaxe2=nOCsDNG(Q9<^|fYIBGlabmTS}HD@;3Y&w^jP97rjNB8riFvT
zfsl?%=rHRnL(d^DBcg4dwzettk*q(muycfX5|L@iG|N35vjbVgQ`oz%m;`k?GYMQ$
zv4gz}r_gy7R}ndg*cwuC;u@j@B3zP>6W4`jM3B$er2Yl#PMFWwNC1|f50!R;92zD?
z@-peruiTD*Rn(B{P>ph?Kem*&;Y3rYpPU{mI-PK4j%_V8PTT9ht>-=@yC{F!PsHTO
z)`7!o7FFTYX;z<R`Ii|+bQWf3U9In|a(;$&x~3QIzwefL*KYDw0lWH-#M+H;f@-(l
z>;#)#pOpk!!$G^-VN>{fI9%xp#!C*k6^bxPG29LYM(K@)V`XdD^z~|TE-4?MzMH(Y
zC?|}~!0+~5x$3+yfSlNDcejE;vmfZ+lAYP?(`b0I@|$5(ucY3s3)dZUb`)#{tuXL6
zn}fkn-%tI1F~_~p3ly*_s%o8AbJmP{MxEn$UKOphs&N!KUR70f*;-cmyY%qpOSJdo
zs3%{aA=wX_16N{ubrB5_Ga_b1%!!y6afTw49Pc*6*2kWtr5~p{`ccMFuKqRi^MUv^
z7VWL!aM!cB^dxap+vX_T-ww=ncMx7&@?PeSYp9&W<p&20KJ9Osu8jVdfi$h(+}rGi
z!LA|UEr=sdy_jvKgw~5c-Orv?DK8`7gWS2~Wl@CdJ9%5Zrf=sTy32&TD<TZT1T{9s
zFl`Ta`_0f5$a_Pt+Fz+F`WO52{3-!g?f}s{g>TiG{-@BJIdNoCYi$ijfh&EZQC!?p
zZ|Nt+3)MFTxF+Jdh_@+R==p3f^*Yr{`iGL966T&{$A2!ic7Nm*wO77cJuCQ>;8J{6
zXM5kv=^x8(%`@GUI*tn=zPuayaoK5@Fjr-5qK*S{X{vBIIl1^O7q*g7;*fXQs+_5$
OD!HkeZP{u;71V$3dmys_

diff --git a/classifiers/attribute_classifiers/eval_atr_classifier.py b/classifiers/attribute_classifiers/eval_atr_classifier.py
index c3222b1..6431390 100644
--- a/classifiers/attribute_classifiers/eval_atr_classifier.py
+++ b/classifiers/attribute_classifiers/eval_atr_classifier.py
@@ -1,4 +1,5 @@
 import sys
+import json
 import os
 import matplotlib.pyplot as plt
 import matplotlib.image as mpimg
@@ -15,6 +16,7 @@ def eval(eval_params):
     x, y, keep_prob = graph_creator.placeholder_inputs()
     _ = graph_creator.obj_comp_graph(x, 1.0)
     obj_feat = tf.get_collection('obj_feat', scope='obj/conv2')
+
     y_pred = graph_creator.atr_comp_graph(x, keep_prob, obj_feat[0])
     accuracy = graph_creator.evaluation(y, y_pred)
 
@@ -23,6 +25,13 @@ def eval(eval_params):
 
     mean_image = np.load(os.path.join(eval_params['out_dir'], 'mean_image.npy'))
     test_json_filename = eval_params['test_json']
+    with open(test_json_filename, 'r') as json_file: 
+        raw_json_data = json.load(json_file)
+        test_json_data = dict()
+        for entry in raw_json_data:
+            if entry['image_id'] not in test_json_data:
+                test_json_data[entry['image_id']]=entry['config']
+
     image_dir = eval_params['image_dir']
     html_dir = eval_params['html_dir']
     if not os.path.exists(html_dir):
@@ -36,12 +45,13 @@ def eval(eval_params):
     color_dict = {
         0: 'red', # blanks are treated as red
         1: 'green',
-        2: 'blue'}
+        2: 'blue',
+        3: 'blank'}
     
     batch_size = 100
     correct = 0
     for i in range(50):  
-        test_batch = atr_data_loader.atr_mini_batch_loader(test_json_filename, image_dir, mean_image, 10000+i*batch_size, batch_size, 75, 75)
+        test_batch = atr_data_loader.atr_mini_batch_loader(test_json_data, image_dir, mean_image, 10000+i*batch_size, batch_size, 75, 75)
         feed_dict_test={x: test_batch[0], y: test_batch[1], keep_prob: 1.0}
         result = sess.run([accuracy, y_pred], feed_dict=feed_dict_test)
         correct = correct + result[0]*batch_size
diff --git a/classifiers/attribute_classifiers/eval_atr_classifier.pyc b/classifiers/attribute_classifiers/eval_atr_classifier.pyc
index 0be14f65263aa832ead7df9c9f5988c223c072a2..960456aa4a4ea6b14e9af628b1864cfd74c89fa3 100644
GIT binary patch
delta 1196
zcmZ`&OK&4Z5dNy|iDSo)_!&>ejuUT)5X%LDB7r!7Hq!2utVqzHRq&>h$P;f~Ziz(d
zkwJk)AR!dZ4gLUUICJ7hpuHjSIB`L6<iOVxTLKA2J>6B^UsZio)%|DllXj5#qg43m
z-fuhmn0|`u8~n##-305CP(cG)Fv=jxBJ^hJ9HJb;`I(wWSV(1s(E^?n5f&j+%M!x!
z3|B$8IFqXgYcsiyumPRHM?lzg@g6>(C5fAF{Bn=$0{?rC)b}^fewdgl>;iNV*9u>}
zjkpL?LsWw0jvy|<)Nu{OWn8`tTf)|(yND~W73dOtTgGt?7oIr1oa)n)*UvrS*xe=Z
zx8AYLy8=I6gp(F2Rbk4ADicHv*s6;)J`=aqoT_pHm=>a?TOcCZR-p-1DYF`e8E7W$
zVN=(k8(f%e3!8BrmN~RxNnIerEcDVXu6>L1Q=DFgo&?0RS{=kR_7Jyl`L!4Vy3Tg%
z(U%gF0eLKg$y$<NZ&qNJV4G+xane0Eu5*anj*?{zkCjl^bYg>IhSUT$%TPY7APQjT
zVRJZA6X;b~pRpungcdAGIdkWnWVmzx!DZbsa-r7HNLdjbz!VX6(s(XNkhUNASm98?
zCoJyFAoTe^kT<)q^wGt!!uY95don9^9}!QPXpq)iAzp!5LDYj~0j6mN$p@JiC+uMd
zzH5c0*rmvJxF)PZ|J#ke1CDF9?ms$p{#0@7p8qbp&Unq;JkjeZj~4WEXfmz)4_pcR
z>6rl{ddoW)c+WSISF)dVg}3{{%yS+-85|w#4Yir0vVVGVI6O+e%U)HCNB!|=aB|F!
z0m-JfBkRfM-p|YKrJSA}oERrW<I`I~5BE)SHusD4lAYYATZRGC!@a}d=)jbyT@2pe
zbMQSqFaws6m-k?7{;}mE((T37wtIlv#czMjb>yxy9gjFcimE9!X)0f8sv;YzF9qrq
z$xBHVrKgg${BrVY{-t`0Xya&)ug2ROT(|qlW&S-`O`hfZ(n$Ww2l7JFF04r_d9AQ2
S50iHbAI>?eq9i{S8h-;kjN#7!

delta 949
zcmZuvO^Xvj5Pj8?Y_ju_k4a`HlgWOItSATyISUF22p)EE7X}o;M%i%{mdqhS4?(gb
z>qVBn`3t;w@?ZD|ym;^sFMHLCC#{}M4x%&NGq1Z}RlTa}H=UnBZ2j=;?QiG%k1>A)
z|8Maey*rC9O;QU<2@;TmqcYMZB$dUoie$OK0!KEUJ4hVRoa!R+7Q7mg`T{qQG#A)M
z(t@;b0wiq{U~i|yfe(N4b8m+jW@n$hTFk)ad7I~G|LpshiDEe}qTGc+9wQG~4QU-|
z6UxJsFtW+)`12*bU}r6)?FFx44rC2eVc2;MhZZEOETb<QkWK#ikS+A{Iutt!pz4N6
z%$rD7Yly)Gapi)@W+8IUu8iEr_yc8p&@o1L9=xZ#RR{>794HoULq<?7s3zJWe`ME;
zX$5)PXi-IoND1@}Y$8$?M<l2cCjFe{5Yi4*6{>=#0y2gQI2PqAP@FY~F0+ks&SH3H
zFdVmxA`_~IcA<*dAu<;!E13rgi{=2Y1)c~*rZ$}gg?$<w;oiR-cEDmHJQEm2V&033
z8m}Vl7Kw&Mu?6xjvKVO(iudPfhWeLvZNdf}M5eTgVZ@O8V5+|g`L`Rn1|C!H*?(Ge
z90?u=$0u}PjF|_%m+OTNZt4aCPHQkU6<jaQG9cZ+blK3rw{TN`EWOqD%VYhuyic3@
z=F&{xsLUwVVfD1KTjh(myZmWZ?a{hxF6Bf1fLsxfPXR^bi#lBvTV#{IYIpU(zFrHM
zky(cPa<sdqPwhw4)93aUHMHl%w5_i=o8;@e&IS$jk@Kq4X0A&v)%kU)!LxY9AMzWg
ATmS$7

diff --git a/classifiers/attribute_classifiers/train_atr_classifier.py b/classifiers/attribute_classifiers/train_atr_classifier.py
index 5145bae..0c59564 100644
--- a/classifiers/attribute_classifiers/train_atr_classifier.py
+++ b/classifiers/attribute_classifiers/train_atr_classifier.py
@@ -1,5 +1,7 @@
 import sys
+import json
 import os
+import time
 import matplotlib.pyplot as plt
 import matplotlib.image as mpimg
 import numpy as np
@@ -14,15 +16,30 @@ def train(train_params):
     x, y, keep_prob = graph_creator.placeholder_inputs()
     _ = graph_creator.obj_comp_graph(x, 1.0)
     obj_feat = tf.get_collection('obj_feat', scope='obj/conv2')
+
+    # Session Saver
+    obj_saver = tf.train.Saver()
+
+    # Restore obj network parameters
+    obj_saver.restore(sess, train_params['obj_model_name'] + '-' + str(train_params['obj_global_step']))
+
     y_pred = graph_creator.atr_comp_graph(x, keep_prob, obj_feat[0])
     cross_entropy = graph_creator.loss(y, y_pred)
     vars_to_opt = tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr')
+    vars_to_restore = tf.get_collection(tf.GraphKeys.VARIABLES, scope='obj')
+    train_step = tf.train.AdamOptimizer(train_params['adam_lr']).minimize(cross_entropy, var_list=vars_to_opt)
+    
+    
     all_vars = tf.get_collection(tf.GraphKeys.VARIABLES)
+    vars_to_init = [var for var in all_vars if var not in vars_to_restore]
     print('Variables that are being optimized: ' + ' '.join([var.name for var in vars_to_opt]))
-    print('All variables: ' + ' '.join([var.name for var in all_vars]))
-    train_step = tf.train.AdamOptimizer(train_params['adam_lr']).minimize(cross_entropy, var_list=vars_to_opt)
+    print('Variables that will be initialized: ' + ' '.join([var.name for var in vars_to_init]))
     accuracy = graph_creator.evaluation(y, y_pred)
     
+    # Session saver for atr variables
+    atr_saver = tf.train.Saver(vars_to_opt)
+    obj_atr_saver = tf.train.Saver(all_vars)
+
     outdir = train_params['out_dir']
     if not os.path.exists(outdir):
         os.mkdir(outdir)
@@ -31,23 +48,32 @@ def train(train_params):
     img_width = 75
     img_height = 75
     train_json_filename = train_params['train_json']
+    with open(train_json_filename, 'r') as json_file: 
+        raw_json_data = json.load(json_file)
+        train_json_data = dict()
+        for entry in raw_json_data:
+            if entry['image_id'] not in train_json_data:
+                train_json_data[entry['image_id']]=entry['config']
+        
+
     image_dir = train_params['image_dir']
-    mean_image = atr_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width)
+    if train_params['mean_image']=='':
+        print('Computing mean image')
+        mean_image = atr_data_loader.mean_image(train_json_data, image_dir, 1000, 100, img_height, img_width)
+    else:
+        print('Loading mean image')
+        mean_image = np.load(train_params['mean_image'])
     np.save(os.path.join(outdir, 'mean_image.npy'), mean_image)
 
     # Val Data
-    val_batch = atr_data_loader.atr_mini_batch_loader(train_json_filename, image_dir, mean_image, 9501, 499, img_height, img_width)
+    print('Loading validation data')
+    val_batch = atr_data_loader.atr_mini_batch_loader(train_json_data, image_dir, mean_image, 9501, 499, img_height, img_width)
     feed_dict_val={x: val_batch[0], y: val_batch[1], keep_prob: 1.0}
     
-    # Session Saver
-    saver = tf.train.Saver()
-    
+  
     # Start Training
-    sess.run(tf.initialize_all_variables())
-
-    # Restore obj network parameters
-    saver.restore(sess, train_params['obj_model_name'] + '-' + str(train_params['obj_global_step']))
-
+    print('Initializing variables')
+    sess.run(tf.initialize_variables(vars_to_init))
 
     batch_size = 10
     max_epoch = 2
@@ -57,18 +83,23 @@ def train(train_params):
     train_acc_array_epoch = np.zeros([max_epoch])
     for epoch in range(max_epoch):
         for i in range(max_iter):
-            train_batch = atr_data_loader.atr_mini_batch_loader(train_json_filename, image_dir, mean_image, 1+i*batch_size, batch_size, img_height, img_width)
+            print(i)
+         #   start = time.time()
+            train_batch = atr_data_loader.atr_mini_batch_loader(train_json_data, image_dir, mean_image, 1+i*batch_size, batch_size, img_height, img_width)
+        #    end = time.time()
+        #    print(end-start)
             feed_dict_train={x: train_batch[0], y: train_batch[1], keep_prob: 0.5}
             _, current_train_batch_acc = sess.run([train_step, accuracy], feed_dict=feed_dict_train)
             train_acc_array_epoch[epoch] =  train_acc_array_epoch[epoch] + current_train_batch_acc
-            val_acc_array_iter[i+epoch*max_iter] = accuracy.eval(feed_dict_val)
-            plotter.plot_accuracy(np.arange(0,i+1+epoch*max_iter)+1, val_acc_array_iter[0:i+1+epoch*max_iter], xlim=[1, max_epoch*max_iter], ylim=[0, 1.], savePath=os.path.join(outdir,'valAcc_vs_iter.pdf'))
-            print('Step: {}  Val Accuracy: {}'.format(i+1+epoch*max_iter, val_acc_array_iter[i+epoch*max_iter]))
+        #    val_acc_array_iter[i+epoch*max_iter] = accuracy.eval(feed_dict_val)
+        #    plotter.plot_accuracy(np.arange(0,i+1+epoch*max_iter)+1, val_acc_array_iter[0:i+1+epoch*max_iter], xlim=[1, max_epoch*max_iter], ylim=[0, 1.], savePath=os.path.join(outdir,'valAcc_vs_iter.pdf'))
+        #    print('Step: {}  Val Accuracy: {}'.format(i+1+epoch*max_iter, val_acc_array_iter[i+epoch*max_iter]))
         
         train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] / max_iter 
         val_acc_array_epoch[epoch] = val_acc_array_iter[i+epoch*max_iter]
         plotter.plot_accuracies(xdata=np.arange(0,epoch+1)+1, ydata_train=train_acc_array_epoch[0:epoch+1], ydata_val=val_acc_array_epoch[0:epoch+1], xlim=[1, max_epoch], ylim=[0, 1.], savePath=os.path.join(outdir,'acc_vs_epoch.pdf'))
-        save_path = saver.save(sess, os.path.join(outdir,'atr_classifier'), global_step=epoch)
+        _ = atr_saver.save(sess, os.path.join(outdir,'atr_classifier'), global_step=epoch)
+        _ = obj_atr_saver.save(sess, os.path.join(outdir,'obj_atr_classifier'), global_step=epoch)
             
     
     sess.close()
diff --git a/classifiers/attribute_classifiers/train_atr_classifier.pyc b/classifiers/attribute_classifiers/train_atr_classifier.pyc
index 78a4e66e7a05fc26d316c57a52f11a19907ab03e..500169dccf17b7b3efcb303cb4ea2258248b13b3 100644
GIT binary patch
literal 4071
zcmcgvOK%*<5w70lF2$EfF5eO*E=5tItXNE(KnVgNwk#X6;74K#k&=a-!EmR!?A7ed
zXr`Cq8Z#&5luNF;<s1Y#BnSVV9CHYe0R92-SJg9I($+bA^j23@*YoRYvi{B7^xxnA
z_p>gQUp0Kbhu`dPSOWYg3W+*rw4#ox_%-U(D!NXcdPPrAzd@a5wVkB?CF)F7+iB{|
zRP-!$<|_I!b>=I2fjWy7-J;GC1vUCT!dND~LP4Eg)F}N6>1BFNq#}v(>D~k)Hvhxd
zXI%yIE&MLx_Yr=x4;w_5sIT}$p~%*#UuXJ@Iu$i4uucebf^y^pH82};q`}VNGnk;)
zN)=6EHmN@;rY0$u{BlJ(;%LxD(4;jxCA&%5c&%s+wiG*@X4^SDpfr~tg40XlbcXUe
z<(I_iROytXouxj?1J_gHdiqVGX%Wp-QC$|tGvc^EPdUmyOSw~;{qA`LSIkk5D4(Mb
zzFemQr|~mGFD5D7QlNiVqnBH%Kw}WdW$G`8_B6ETOYQ!mD1E9p)5jkLP(9O||GwU^
z)Z==z&ZWIjC4ho*v{Wpf<ID*M`Y7=NZ4^snm#Du?{T0fWSeE@qF=w^h3MLk*Sf&6g
zlyI(vbE{<6C|{)jhidf4UsJJ0cAfegl&^`4b+T8auN!pwW6IZ&z+Eb?P=1AiW!lI$
zq%`+yykC)hDS7bnehoj*pX1Ja{u}nU3V&C{^BO?0E|~pmE!d#PAK9za-y|5zuTlU@
zP(zGBbCZfqiq<LEWUXrw7)-$cS7eizb-l`cfs+lcaVGgDuPV4+8s937gCA^B(5~!H
zQgMw8uy#!l2sqwnavPb{={{f&blxwnQ*Z-@0O@{Pny^jzJeLPQD%ha{CT~*FreK$f
zZ3*fI<xpzVa!<618^YaTR>|4<9_J>V{z=)JtOLUzeEA1e?1Dwi+AdjC+@k!}8?u5o
zNVaHs4NF-)WOrM#yFJbhWR)HCIlDb_iRHK#fHUs)ZOZrHPkkYGzPLj<BG^Su?$DYe
zAH4iW{dvXRrj?SLBW+mS<#?CZHge><OAqWW^>0zICkaa^;35j11N;$)15ALDy&b$o
zIC`FOmle5oIPn9PpfyateX;lVxa+T}sq#g}1@bDdg@eXEi-RAspVu{2G$?PqN$p49
zReSkoB#Y*wC>XCii0%jlXg}S12jlcJj12zz@4NOAKH|f^ccgu5F=Q$0#);0R9DBbT
zM<?&t%casF4s_^6{y<xv&~MxMQqT+ILqGH~s}lnj**X^Art%5%c9QpjKk&j-T%P!;
z7n;mwHz9cDr^Y`Fb=J1WzHR%dZXaqB_1bY_&A@!6gP*s>Vf#($uS^(1zilF8jUSfk
z>@Xf$FEFY37_^<{CG}0@^|Lq<b1oD!@O#=bfo(zq`5u`b#y_4S;p4m<9cVxDgoS#d
zO3PRt#e-yMIZ&3iMQZ+q)D0?4SqsZ2u^(Ji$heqeYxkmLl&zi9JV6S9Z%rJv@!^|m
zyRi2S%5Avh*(#PNRi<nim2|~3oT1_O-+YZ2$`=3gT}kFh>j1Urx*nghWjh1a5mTPj
zMfBVrG1v)6PeU^hV#EYPCG$`8uY7x)Enwq!yWUCW=_KwR-%Elc!3nzVbwfYP%#qQl
zI5~$7%SB<lT~yoS#~x1#mLpMl5?P)4U29JCQw>|_*c_z2)K8ARZVIr*sVF9)-_^%)
z80ge95!#WRp|y+lcydBmJ#D>i9EKWJVUDrx@@DEO@=67zsm@?q3vx3cnA;`6j;(QN
z!Z^#MVvjkbkM+pyo;^Hx^6>pnK71<SJ_O=_Sz_3MZ2*)qqNS?+*hE6ejpFucRDLMU
zmtw@3h!YfC(5X-P%1G#g7mQ4f1lQGB5nQ|2{6Iv~aK59QfOI+p9(8)a-4hQ{B8uEr
zyHYB0jNYMdyT_hWms+I5$TH~1*UUSEo<yVvi5&^`l}-^kmo@bxl%4Nl62=xSm5IT$
z2k;EhVS{q>#q*siAL^pE8iNqWH3%=zNB%Ihu9Xfcvk<z=oZ#TVl_>*I<X)f3I+FIh
z&^qywG(HrbcYuQ=0tO+cor?PrbaWt1>mqo7I<je;j0C^PGxKchVH`>P)xLttc(m{U
zlzZZbA(tnU(t)rXadnPM=(Jo7T%_D17Z4-8x+CKdG0nK;ZZ!40SD0AGGMdbw=N)U)
zJGN4^@sRV5OsEB0e2C#89LX~EU&$a8SV&gq&dGS-R&W(Y2aF&G8MXt6I&_Ggv%!LN
zLkY(rCu5Ma`oKSR6GAfPonajC9AHa4L;N)LM~=;1zsQ#9BG_~Rjt@AAgvk>~8f>It
zPE!Jr)W$)I1b;&UWy3T@dT#cWaEI+Zh)9;~x1jcq<AL6{el+k$`;YMa?LXGhv(Fz^
zALB`v?c)JW&Ee4MbMr6TFQdLN74IdZgB!5Ujn445CSK{;+KOtamHK@BR&4{+RebN(
zSJbZB!*{FRQmxuF_zU%>YAUAY)s&i33u>!AjkTe+YYnx89VoRk(NL>3)@-U(=(ZqP
zQjOZ3`VOS4jTX3>x~=Zwtn|h62K+R7d;Gq***D4#euB-xK67}i;4Rjc|05gtmalMV
z4v~mlY`(!I$uEBrTFWEmoT7U=xmjaskN~<eC!*m1CHWbo?w!_A7N<vH{OVwnc?_1b
z5oThbJxnmge#)|SUXI*wa4!R<d!=!mh5XLR$SIG#hrN`)%$#osG6&Z+{{PT@@%#XP
bP7Oo-3whVC!27In-}C5gMbFlzyjk@v*e<|h

literal 3616
zcmcgtOK&5`5w0dBN}{O8QV)vKYL~KhH`*9p0Xb|CV-ZKT1KDuk#1fz&gzZ6#(<FyG
zubpnmqBC<+{y+`^@*9$4fc%8~iCl6FkV^pnfaI$hQV%<TZ{ee-y1JfURrOTAS*ri(
z;lIv$H2GBU{2BiC4_JKs6oo|H8(L9UP5BiXRjE5O-OkdeM%}sTcAmNmQ@T#w#VOsO
z?$VTArtZp=UZw5>@+<T^xNDMGBfm<oE0q17Op~sOR4mb)9?T%J;6M0$+EXC+@n6ON
z1N`mZaOj9eiZu#FPGSW8d6kL^6<BA4IYW7cle-3UR#ax$IehvvbgfiT6R{eL<q)os
zKTGXv=+DW)9OYHY=R|*gqR)=!X@sm{r6yJuzI}A9sHl?<lptBB0a3m{AAYey#Uka4
zG+LxDEA-|sP^Q-LM?P%b$ko3q_xFh$dv4HZiE?06N1TdA%aqTP5336jNkhV3xf6CN
z-Jo`{Om3A%4`|e+e3?z!GsT>>$yT6Wq5`$Tx+-3lq4z+j6;Mrb>onS+9IPh!tK@@#
z;0j&Ws8}PnNu$=(OP$U7C@P>`<5jWp)<hXBAHVgf{u~wS6aegXDFN#7lF1EfqXI8c
zLX_uau^~VMgSJ$DlX5_?EzBL}ASsn2h=+W@C~Z>PZ&Azb%5F<AYjG@on~E0s+f;0k
zze7b^f`3GLoANDMAE0MYLzsrO$IO~=cD}=TNLQVT>#z(|KK!CtDR$}VSTPG>GV3R*
zcqGu>%^onqU5nORSSBa}g~x)z<69I!PEkNwi~<Zx+IfTWB^scgkzNO`)aTOY#UAAl
z-34}g)GEnWkQ{-(!4c)#;u~6RxPRr|qS3zO^@QCbEO=}|Uygy6Y`6c%*`6=5NS#GF
z5(ymTV!Cs`752CZt}9jL!{5_yySX1z8=}V_z4>$H*V9bev@y-(Ui9sJhf|d#o{@hy
zk6j6T62Je1LE-dD_MVRrcJ48zxqMF^&m<43NvEN7xPQvIDm(rMeC!{vpX-V$-l8{u
zL>4F*4sgHm5FZGId-wF<IqtJhaX;|q-|xA3Jd*Q~*Eia^S!`@CNsV17*@wL(zWh;u
z4Y?{7?K1lW&oi9``aCqY<A&OGbY?o|CWr@}By~X)d~N)9JC1MBj$Ofe6o#G4sR~Z!
z*p#m$FU%wWoOxkjog5?=&hvvT_yj_(&P%3)*c;g-b{tj^>46c6e<Gv+Ns(Q~5*ZzP
za`+-n$HDGX9Dais0*=G61m_tgz6rfpM@Hg%$8mgoFbI=#9eUQ8G+?g*zp{eA!o>uv
zf4nE|`^Nb2-g6G`1iA`?NQf~L45rnW=2zMcZ4;YI9Uk?1-lg>dXR;Tm-?!UfpTfzz
zo!=In&Y2E7khsWnZ_FGi!IiI_c8d@kGvx^Y;=J6$s^wRcE=`*>=?&SoB#9Dvy--^l
z^aGQLl{=|P`SJDwB7MgZejFn(-E+aEIW<rM$V*@kGMx^+UWTqnG7(He-7~`^^iAdk
zaeCqG4W_*$N<GmK*1$NgmxQ5#RhVPsC8%2|MlMXOm7I>*)hA|L?#_-*K0bQ?7e7B0
zR7Xh=3n2yL_AIgB5RcG6znrvN*+x+iOY=$B7@%~aB?3-2v0^7hWn2w=bA^7h0(bNp
z?IwM|VG815qLHWAxos<fEl$M_SJ`dA4y&+V?_9gy(3>c7q^A(cWNNaD*lmFWF&F3%
z<KgKr^lm0X+CDQDP#aOOE|bVil)AA{Uz-dWu-Qz<0GJWv3`I2!6Nf5F;h6T&7}8Oa
zsV+4*>3)($7;|p<yK(_sv*{jcV%!=i1cM^_rmrtTSC;l*AAHx~MdmSu-KYfBqH$!J
zu4ElcSzjBIdTExN3(q@&B1t@BaE&j)_K=kIOzg5G9ZRbM6YIIe!#yI?G{MOVEI7P#
zA@HP@GIz7Z?ZMS7FN%TpRp6uGGSAV#8=7D+bdurirQ-F2&~Sg9@c5qmkmF=^pj#GI
zLcJM#lr0prxCP31%F<iZRTLwA<;hG+biCu?mmy-CJheK@^tfa*-1pg~E;m;h<gG$^
z!lf!HzS#o7t(25Nq`bG`%86@15<QUt>+(98^eb#j%$$tJ!*KvXlEwZ8)Zs9R%%Rh9
zq{oLJ;O;y;Hu2eKN7KjcO=AzySy^y?;mn=;#~w~{xOXePNXI98u+CSU{SphQMrA`a
z)rM+g-NEw!->0fssbjxdt*M%-<NHWGuI{P*>PmI3`b@YrwF!x)TEq9b+E;s`pQ?8%
zO`KTucD1gg4xShBYs5QPhr4ph#`feXZysR5T<}PQ!8wAHf#QLgko@XPLx;ESJwbWd
zNH5${Q5r-8;l&pb(tH~#Wz(A2CRsmBzB-Z8^5uJT-vkMIGfYi(@|@-QvT*%!?#fFT
z7wJ@jODQ#%5uT|17)*h5k}LN!as{_6=KrCh;`tTcuNR^DDR-vb#L`d=wSq4{<y%>J
HzoGsM4fITj

diff --git a/classifiers/object_classifiers/train_obj_classifier.py b/classifiers/object_classifiers/train_obj_classifier.py
index 45d5830..debc76d 100644
--- a/classifiers/object_classifiers/train_obj_classifier.py
+++ b/classifiers/object_classifiers/train_obj_classifier.py
@@ -26,7 +26,12 @@ def train(train_params):
     img_height = 75
     train_json_filename = train_params['train_json']
     image_dir = train_params['image_dir']
-    mean_image = shape_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width)
+    if train_params['mean_image']=='':
+        print('Computing mean image')
+        mean_image = atr_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width)
+    else:
+        print('Loading mean image')
+        mean_image = np.load(train_params['mean_image'])
     np.save(os.path.join(outdir, 'mean_image.npy'), mean_image)
 
     # Val Data
diff --git a/classifiers/object_classifiers/train_obj_classifier.pyc b/classifiers/object_classifiers/train_obj_classifier.pyc
index eac87e3278c2ff50ee06780b0e6328693ac40b63..b0dd5d511a461b073a33139526028efddda13b3c 100644
GIT binary patch
delta 1088
zcmah{J8u&~5dL=0Ifpwtex09j{0Imk5CsTDyhOo6NfD(81!<76w8r>wY{xzz#hzTa
zp|$QO(5Fd_XhKIzNk>6}XlN*zx!90INa%EPbKiV3^UdrY6y6u3#?QFrUHd+bev;+L
z^z+eJ1Ik2EMe$N%_zP+g>l@w@+yrzCIu3OX%7V6`Elf=KOXwA#Y-T0lrZ6io5gcMD
z4bOsW!_7cbZVVF%y~I{{4&3alE<q}ev09qfj>1o%V?ZaM5-4hkHFlHa8dy`*DS|aW
z2@h~hxH(J!pHKleAA$u&NQY2q_-UvN{0x)>-wAaUp&jT9oVH+=pJjB%EipEOav|{~
zlE;2cddtkAoQF@-(l20tT`=SvL$`!q#M#M`!3SBrIv=C-cMKYz$Ds63w~j(DqAj6v
zfo)tIm!XqfRuhi|7r2X)zw52Qt%TyLyc`PAUh?xWtv3A4{v%dbc@C0(uEdS#^k}WA
zuqg>Yc~+$xd`uH*2(=8icDB!HQAMJzZ|n&fd^IkQNnjnSd_p7jA&oO_Y6Y5vko04?
zt}4*EQ@~v3e>7G@i7<su@m-@e@(4Ho+_xIXRh=76r0^@)`yrn2vRC<f2);e%Lr(#Z
zFRcb2Blotb0<ha^_8OhtX1hfkVC0Z_(BB=r@;beCjia?N+WBSx9m>WiO=h#-RHq4t
z@r9TW7lM*`bUQ(erZ;S;rq^tAiP{=&M+oyIncl!-pW)CMHD9+v$L)7|!51@sSdLDG
zh#Vd)y5f44s5ZT$9WoFxnH6bzsv<2d`X&uin8KnbDHcRsE{IHU-`WzzU}`<fRwyN$
z(rfib{o(U&|IKzHFzrWTHQ2Q4%PEd?^+C7qHD0v30}^_L{<pKiTl<}HCnkOa4u8-H

delta 923
zcmah{J#W)c6g|%lNo(iZPMkDpK02TXi46jXj|o&Jgv!8#f~9ArvC<?h7V7Y<sLGC1
zEdBsYEG(H}>d4Z8i9di45(9#ndrbmuL1MA)JNI1QbI-kRt8XeP{aePk^5gSi>b>?w
zKm2es4YKfaP-&<PWD+t3WkMO4B@t#3<lz?(81PN_MW|KC70fiKEPn+TTZ>|XB&CgI
zeF|X?105;{X`-q$bU92@sG}=0PziLyX&wRKC*j+e0bv13z%RvMnUR@*kXeL9$Q(io
zGLO)XbvaNrlm)vlm=#hHsjV=yK+1n14)$AGP-PC4+<sK-Xki8Wn}RV?z}PhrR&jQ+
zbo5ngUR=(x_Gb<f%yZNhx)u4ig1!b>jIQhFr)yAo{;euOQ43OM!**^2HTd;d-8#EN
zDKf@crpem)S^7t=Ey)Ag7THdJ;zl;J$WM;RgOj@|U*a{*G+jsse&cMP)1s<G-MqVd
zEFwE`VM+p<kd+e}E3v4BHc7)cvAkb!ThgTEQ@~Q^e>B!&iAV}n;9H~h^9Wae-&>vY
za-F@KX(DW3_nmmiN7CZ+5PUuDIFi7FP9yrBxMd7J>A-*h%n#0`h1iImC68_v=>25d
zop?Rz25xUScBQvR&43-x9rXq~Zr=+udX6SRg1PQfFAn_iU^IW3dM-qLetFdwhDBME
z8TX^tMpHDRW8<E1qL%qstVai?B{I<q^Ny%RAI)|xN5p{Hli@h%ZF|Fsx3|@y_m&ko
K=~sF;Eq(%qwX%Bv

diff --git a/classifiers/tf_graph_creation_helper.pyc b/classifiers/tf_graph_creation_helper.pyc
index ba8c1623b0ba07ff0da1f5e5eceedd6f69bfefaa..cc7087d41170050991a04ef6986c615efc6ba662 100644
GIT binary patch
delta 16
XcmbQCGed`+`7<w9K=SL2?1EwdEs_Nq

delta 16
XcmbQCGed`+`7<w<^4*sk*#*S_FDV6O

diff --git a/classifiers/train_classifiers.py b/classifiers/train_classifiers.py
index d2edabb..e7d0ba9 100644
--- a/classifiers/train_classifiers.py
+++ b/classifiers/train_classifiers.py
@@ -12,7 +12,7 @@ import attribute_classifiers.eval_atr_classifier as atr_evaluator
 workflow = {
     'train_obj': False,
     'eval_obj': False,
-    'train_atr': True,
+    'train_atr': False,
     'eval_atr': True,
 }
 
@@ -21,6 +21,8 @@ obj_classifier_train_params = {
     'adam_lr': 0.001,
     'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
     'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
+    'mean_image': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier/mean_image.npy',
+#    'mean_image': '',
 }
 
 obj_classifier_eval_params = {
@@ -39,11 +41,13 @@ atr_classifier_train_params = {
     'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
     'obj_model_name': obj_classifier_eval_params['model_name'],
     'obj_global_step': 1,
+    'mean_image': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier/mean_image.npy',
+#    'mean_image': '',
 }
 
 atr_classifier_eval_params = {
     'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier',
-    'model_name': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/atr_classifier',
+    'model_name': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/obj_atr_classifier',
     'global_step': 1,
     'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
     'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
-- 
GitLab