From 99ac4c72adfe64ae6c4f84d87c78946ae232f2c8 Mon Sep 17 00:00:00 2001
From: tgupta6 <tgupta6@illinois.edu>
Date: Thu, 17 Mar 2016 16:10:28 -0500
Subject: [PATCH] Scope variable names to limit which variables are optimized
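
Build the attribute classifier on top of the pretrained object network: the
object graph is constructed first, its conv2 features are fetched through the
'obj_feat' collection, and the Adam optimizer is restricted to variables in
the 'atr' scope (via var_list) so the restored 'obj' weights stay frozen.
Also flip the mean-image normalization to (image / 254.0) - mean_image so it
can be inverted when saving misclassified crops for the HTML reports.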

---
 classifiers/attribute_classifiers/__init__.py |   0
 .../attribute_classifiers/__init__.pyc        | Bin 0 -> 164 bytes
 .../atr_data_io_helper.py                     |  83 ++++++++++++++++++
 .../atr_data_io_helper.pyc                    | Bin 0 -> 4641 bytes
 .../eval_atr_classifier.py                    |  66 ++++++++++++++
 .../eval_atr_classifier.pyc                   | Bin 0 -> 2804 bytes
 .../train_atr_classifier.py                   |  64 +++++++-------
 .../train_atr_classifier.pyc                  | Bin 0 -> 3616 bytes
 .../object_classifiers/eval_obj_classifier.py |  36 +++++---
 .../eval_obj_classifier.pyc                   | Bin 2671 -> 2699 bytes
 .../object_classifiers/obj_data_io_helper.py  |  13 ++-
 .../object_classifiers/obj_data_io_helper.pyc | Bin 4653 -> 4614 bytes
 .../train_obj_classifier.py                   |   7 +-
 .../train_obj_classifier.pyc                  | Bin 3113 -> 3118 bytes
 classifiers/train_classifiers.py              |  38 ++++++--
 15 files changed, 246 insertions(+), 61 deletions(-)
 create mode 100644 classifiers/attribute_classifiers/__init__.py
 create mode 100644 classifiers/attribute_classifiers/__init__.pyc
 create mode 100644 classifiers/attribute_classifiers/atr_data_io_helper.py
 create mode 100644 classifiers/attribute_classifiers/atr_data_io_helper.pyc
 create mode 100644 classifiers/attribute_classifiers/eval_atr_classifier.py
 create mode 100644 classifiers/attribute_classifiers/eval_atr_classifier.pyc
 create mode 100644 classifiers/attribute_classifiers/train_atr_classifier.pyc

diff --git a/classifiers/attribute_classifiers/__init__.py b/classifiers/attribute_classifiers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/classifiers/attribute_classifiers/__init__.pyc b/classifiers/attribute_classifiers/__init__.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae52301a780f66bda35035e28368b85b5b73608d
GIT binary patch (literal 164; base85-encoded compiled bytecode omitted)

diff --git a/classifiers/attribute_classifiers/atr_data_io_helper.py b/classifiers/attribute_classifiers/atr_data_io_helper.py
new file mode 100644
index 0000000..0f707a9
--- /dev/null
+++ b/classifiers/attribute_classifiers/atr_data_io_helper.py
@@ -0,0 +1,83 @@
+import json
+import sys
+import os
+import matplotlib.pyplot as plt
+import matplotlib.image as mpimg
+import numpy as np
+import tensorflow as tf
+from scipy import misc
+
+def atr_mini_batch_loader(json_filename, image_dir, mean_image, start_index, batch_size, img_height=100, img_width=100, channels=3):
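+    # Load batch_size images, split each into a 3x3 grid of crops, and return
+    # the (optionally mean-subtracted) crops with one-hot attribute labels.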
+
+    with open(json_filename, 'r') as json_file: 
+        json_data = json.load(json_file)
+
+    atr_images = np.empty(shape=[9*batch_size, img_height/3, img_width/3, channels])
+    atr_labels = np.zeros(shape=[9*batch_size, 4])
+
+    for i in range(start_index, start_index + batch_size):
+        image_name = os.path.join(image_dir, str(i) + '.jpg')
+        image = misc.imresize(mpimg.imread(image_name),(img_height, img_width), interp='nearest')
+        crop_shape = np.array([image.shape[0], image.shape[1]])/3
+        selected_anno = [q for q in json_data if q['image_id']==i]
+        grid_config = selected_anno[0]['config']
+        
+        counter = 0
+        for grid_row in range(0,3):
+            for grid_col in range(0,3):
+                start_row = grid_row*crop_shape[0]
+                start_col = grid_col*crop_shape[1]
+                cropped_image = image[start_row:start_row+crop_shape[0], start_col:start_col+crop_shape[1], :]
+                if np.ndim(mean_image)==0:
+                    atr_images[9*(i-start_index)+counter,:,:,:] = cropped_image/254.0
+                else:
+                    atr_images[9*(i-start_index)+counter,:,:,:] = (cropped_image / 254.0) - mean_image
+                atr_labels[9*(i-start_index)+counter, grid_config[6*grid_row+2*grid_col+1]] = 1
+                counter = counter + 1
+
+    return (atr_images, atr_labels)
+
+def mean_image_batch(json_filename, image_dir, start_index, batch_size, img_height=100, img_width=100, channels=3):
+    batch = atr_mini_batch_loader(json_filename, image_dir, np.empty([]), start_index, batch_size, img_height, img_width, channels)
+    mean_image = np.mean(batch[0], 0)
+    return mean_image
+
+def mean_image(json_filename, image_dir, num_images, batch_size, img_height=100, img_width=100, channels=3):
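+    # Average the per-batch means over floor(num_images/batch_size) batches.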
+    max_iter = np.floor(num_images/batch_size)
+    mean_image = np.zeros([img_height/3, img_width/3, channels])
+    for i in range(max_iter.astype(np.int16)):
+        mean_image = mean_image + mean_image_batch(json_filename, image_dir, 1+i*batch_size, batch_size, img_height, img_width, channels)
+
+    mean_image = mean_image/max_iter
+    return mean_image
+
+
+class html_atr_table_writer():
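+    # Minimal helper that streams an HTML table to disk, one row per add_element call.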
+    def __init__(self, filename):
+        self.filename = filename
+        self.html_file = open(self.filename, 'w')
+        self.html_file.write("""<!DOCTYPE html>\n<html>\n<body>\n<table border="1" style="width:100%"> \n""")
+    
+    def add_element(self, col_dict):
+        self.html_file.write('    <tr>\n')
+        for key in range(len(col_dict)):
+            self.html_file.write("""    <td>{}</td>\n""".format(col_dict[key]))
+        self.html_file.write('    </tr>\n')
+
+    def image_tag(self, image_path, height, width):
+        return """<img src="{}" alt="IMAGE NOT FOUND!" height={} width={}>""".format(image_path,height,width)
+        
+    def close_file(self):
+        self.html_file.write('</table>\n</body>\n</html>')
+        self.html_file.close()
+
+
+if __name__=="__main__":
+
+    html_writer = html_atr_table_writer('/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier_v_1/trial.html')
+    col_dict={0: 'sam', 1: html_writer.image_tag('something.png',25,25)}
+    html_writer.add_element(col_dict)
+    html_writer.close_file()
+
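
A minimal usage sketch for the helper above (paths are hypothetical; assumes
the shapes dataset layout used elsewhere in this patch, with images named
<image_id>.jpg and annotations carrying an 'image_id' and a 9-cell 'config'):

    import atr_data_io_helper as atr_data_loader

    json_file = 'shapes_dataset/train_anno.json'   # hypothetical path
    image_dir = 'shapes_dataset/images'            # hypothetical path

    # Per-pixel mean over the first 1000 images, computed in batches of 100
    mean = atr_data_loader.mean_image(json_file, image_dir, 1000, 100, 75, 75)

    # One mean-subtracted mini-batch of 10 images -> 90 grid-cell crops
    images, labels = atr_data_loader.atr_mini_batch_loader(
        json_file, image_dir, mean, 1, 10, 75, 75)
    assert images.shape == (90, 25, 25, 3) and labels.shape == (90, 4)
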
diff --git a/classifiers/attribute_classifiers/atr_data_io_helper.pyc b/classifiers/attribute_classifiers/atr_data_io_helper.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b958afac0eda1f2f3ff3b1e2e2dc36c210e170a5
GIT binary patch (literal 4641; base85-encoded compiled bytecode omitted)

diff --git a/classifiers/attribute_classifiers/eval_atr_classifier.py b/classifiers/attribute_classifiers/eval_atr_classifier.py
new file mode 100644
index 0000000..c3222b1
--- /dev/null
+++ b/classifiers/attribute_classifiers/eval_atr_classifier.py
@@ -0,0 +1,66 @@
+import sys
+import os
+import matplotlib.pyplot as plt
+import matplotlib.image as mpimg
+import numpy as np
+from scipy import misc
+import tensorflow as tf
+import atr_data_io_helper as atr_data_loader 
+import tf_graph_creation_helper as graph_creator
+import plot_helper as plotter
+
+def eval(eval_params):
+    sess=tf.InteractiveSession()
+    
+    x, y, keep_prob = graph_creator.placeholder_inputs()
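+    # Build the object network first so its conv2 features (published via the
+    # 'obj_feat' collection) can feed the attribute graph below.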
+    _ = graph_creator.obj_comp_graph(x, 1.0)
+    obj_feat = tf.get_collection('obj_feat', scope='obj/conv2')
+    y_pred = graph_creator.atr_comp_graph(x, keep_prob, obj_feat[0])
+    accuracy = graph_creator.evaluation(y, y_pred)
+
+    saver = tf.train.Saver()
+    saver.restore(sess, eval_params['model_name'] + '-' + str(eval_params['global_step']))
+
+    mean_image = np.load(os.path.join(eval_params['out_dir'], 'mean_image.npy'))
+    test_json_filename = eval_params['test_json']
+    image_dir = eval_params['image_dir']
+    html_dir = eval_params['html_dir']
+    if not os.path.exists(html_dir):
+        os.mkdir(html_dir)
+    html_writer = atr_data_loader.html_atr_table_writer(os.path.join(html_dir, 'index.html'))
+    col_dict = {
+        0: 'Ground Truth',
+        1: 'Prediction',
+        2: 'Image'}
+    html_writer.add_element(col_dict)
+    color_dict = {
+        0: 'red', # blanks are treated as red
+        1: 'green',
+        2: 'blue'}
+    
+    batch_size = 100
+    correct = 0
+    for i in range(50):  
+        test_batch = atr_data_loader.atr_mini_batch_loader(test_json_filename, image_dir, mean_image, 10000+i*batch_size, batch_size, 75, 75)
+        feed_dict_test={x: test_batch[0], y: test_batch[1], keep_prob: 1.0}
+        result = sess.run([accuracy, y_pred], feed_dict=feed_dict_test)
+        correct = correct + result[0]*batch_size
+        print(correct)
+
+        for row in range(batch_size*9):
+            gt_id = np.argmax(test_batch[1][row,:])
+            pred_id = np.argmax(result[1][row, :])
+            if not gt_id==pred_id:
+                img_filename = os.path.join(html_dir,'{}_{}.png'.format(i,row))
+                misc.imsave(img_filename, test_batch[0][row,:,:,:] + mean_image)
+                col_dict = {
+                    0: color_dict[gt_id],
+                    1: color_dict[pred_id],
+                    2: html_writer.image_tag('{}_{}.png'.format(i,row), 25, 25)}
+                html_writer.add_element(col_dict)
+
+    html_writer.close_file()
+    print('Test Accuracy: {}'.format(correct / 5000))
+
+    sess.close()
+    tf.reset_default_graph()
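
A minimal driver sketch for eval above (paths are hypothetical; the same keys
are wired up for real in classifiers/train_classifiers.py later in this patch):

    import eval_atr_classifier as atr_evaluator

    atr_evaluator.eval({
        'model_name': '/tmp/Atr_Classifier/atr_classifier',  # hypothetical
        'global_step': 1,
        'out_dir': '/tmp/Atr_Classifier',  # must already hold mean_image.npy
        'test_json': 'shapes_dataset/test_anno.json',
        'image_dir': 'shapes_dataset/images',
        'html_dir': '/tmp/Atr_Classifier/html_dir',
    })
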
diff --git a/classifiers/attribute_classifiers/eval_atr_classifier.pyc b/classifiers/attribute_classifiers/eval_atr_classifier.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0be14f65263aa832ead7df9c9f5988c223c072a2
GIT binary patch (literal 2804; base85-encoded compiled bytecode omitted)

diff --git a/classifiers/attribute_classifiers/train_atr_classifier.py b/classifiers/attribute_classifiers/train_atr_classifier.py
index dd490ad..5145bae 100644
--- a/classifiers/attribute_classifiers/train_atr_classifier.py
+++ b/classifiers/attribute_classifiers/train_atr_classifier.py
@@ -5,36 +5,33 @@ import matplotlib.image as mpimg
 import numpy as np
 import tensorflow as tf
 import atr_data_io_helper as atr_data_loader
+import tf_graph_creation_helper as graph_creator
+import plot_helper as plotter
 
-def train():
-    # Start session
+def train(train_params):
     sess = tf.InteractiveSession()
 
-    x, y, keep_prob = placeholder_inputs()
-    y_pred = comp_graph_v_2(x, y, keep_prob)
-
-    # Specify loss
-    cross_entropy = -tf.reduce_sum(y*tf.log(y_pred))
-    
-    # Specify training method
-    train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)
-
-    # Evaluator
-    accuracy = evaluation(y, y_pred)
+    x, y, keep_prob = graph_creator.placeholder_inputs()
+    _ = graph_creator.obj_comp_graph(x, 1.0)
+    obj_feat = tf.get_collection('obj_feat', scope='obj/conv2')
+    y_pred = graph_creator.atr_comp_graph(x, keep_prob, obj_feat[0])
+    cross_entropy = graph_creator.loss(y, y_pred)
+    vars_to_opt = tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr')
+    all_vars = tf.get_collection(tf.GraphKeys.VARIABLES)
+    print('Variables that are being optimized: ' + ' '.join([var.name for var in vars_to_opt]))
+    print('All variables: ' + ' '.join([var.name for var in all_vars]))
+    train_step = tf.train.AdamOptimizer(train_params['adam_lr']).minimize(cross_entropy, var_list=vars_to_opt)
+    accuracy = graph_creator.evaluation(y, y_pred)
     
-    # Merge summaries and write them to ~/Code/Tensorflow_Exp/logDir
-    merged = tf.merge_all_summaries()
-
-    # Output dir
-    outdir = '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier_v_1/'
+    outdir = train_params['out_dir']
     if not os.path.exists(outdir):
         os.mkdir(outdir)
 
     # Training Data
     img_width = 75
     img_height = 75
-    train_json_filename = '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json'
-    image_dir = '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images'
+    train_json_filename = train_params['train_json']
+    image_dir = train_params['image_dir']
     mean_image = atr_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width)
     np.save(os.path.join(outdir, 'mean_image.npy'), mean_image)
 
@@ -44,12 +41,17 @@ def train():
     
     # Session Saver
     saver = tf.train.Saver()
-
+    
     # Start Training
     sess.run(tf.initialize_all_variables())
-    batch_size = 100
-    max_epoch = 10
-    max_iter = 95
+
+    # Restore obj network parameters
+    saver.restore(sess, train_params['obj_model_name'] + '-' + str(train_params['obj_global_step']))
+
+
+    batch_size = 10
+    max_epoch = 2
+    max_iter = 950
     val_acc_array_iter = np.empty([max_iter*max_epoch])
     val_acc_array_epoch = np.zeros([max_epoch])
     train_acc_array_epoch = np.zeros([max_epoch])
@@ -57,24 +59,20 @@ def train():
         for i in range(max_iter):
             train_batch = atr_data_loader.atr_mini_batch_loader(train_json_filename, image_dir, mean_image, 1+i*batch_size, batch_size, img_height, img_width)
             feed_dict_train={x: train_batch[0], y: train_batch[1], keep_prob: 0.5}
-
             _, current_train_batch_acc = sess.run([train_step, accuracy], feed_dict=feed_dict_train)
-
             train_acc_array_epoch[epoch] =  train_acc_array_epoch[epoch] + current_train_batch_acc
             val_acc_array_iter[i+epoch*max_iter] = accuracy.eval(feed_dict_val)
+            plotter.plot_accuracy(np.arange(0,i+1+epoch*max_iter)+1, val_acc_array_iter[0:i+1+epoch*max_iter], xlim=[1, max_epoch*max_iter], ylim=[0, 1.], savePath=os.path.join(outdir,'valAcc_vs_iter.pdf'))
             print('Step: {}  Val Accuracy: {}'.format(i+1+epoch*max_iter, val_acc_array_iter[i+epoch*max_iter]))
-            plot_accuracy(np.arange(0,i+1+epoch*max_iter)+1, val_acc_array_iter[0:i+1+epoch*max_iter], xlim=[1, max_epoch*max_iter], ylim=[0, 1.], savePath=os.path.join(outdir,'valAcc_vs_iter.pdf'))
-            
-            
+        
         train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] / max_iter 
         val_acc_array_epoch[epoch] = val_acc_array_iter[i+epoch*max_iter]
-        
-        plot_accuracies(xdata=np.arange(0,epoch+1)+1, ydata_train=train_acc_array_epoch[0:epoch+1], ydata_val=val_acc_array_epoch[0:epoch+1], xlim=[1, max_epoch], ylim=[0, 1.], savePath=os.path.join(outdir,'acc_vs_epoch.pdf'))
-
-        save_path = saver.save(sess, os.path.join(outdir,'obj_classifier_{}.ckpt'.format(epoch)))
+        plotter.plot_accuracies(xdata=np.arange(0,epoch+1)+1, ydata_train=train_acc_array_epoch[0:epoch+1], ydata_val=val_acc_array_epoch[0:epoch+1], xlim=[1, max_epoch], ylim=[0, 1.], savePath=os.path.join(outdir,'acc_vs_epoch.pdf'))
+        save_path = saver.save(sess, os.path.join(outdir,'atr_classifier'), global_step=epoch)
             
     
     sess.close()
+    tf.reset_default_graph()
 
 if __name__=='__main__':
     train()
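
The var_list restriction above is what the commit subject refers to: only
variables created under the 'atr' scope receive gradient updates, while the
restored 'obj' network stays frozen. A self-contained sketch of the mechanism
(illustrative names, not the project's graph-creation code):

    import tensorflow as tf

    with tf.variable_scope('obj'):
        w_obj = tf.Variable(1.0, name='w')   # pretrained; should stay frozen
    with tf.variable_scope('atr'):
        w_atr = tf.Variable(1.0, name='w')   # newly trained

    loss = tf.square(w_obj * w_atr - 2.0)
    atr_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr')
    # Adam only touches atr/w; obj/w keeps its restored value
    step = tf.train.AdamOptimizer(0.001).minimize(loss, var_list=atr_vars)
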
diff --git a/classifiers/attribute_classifiers/train_atr_classifier.pyc b/classifiers/attribute_classifiers/train_atr_classifier.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78a4e66e7a05fc26d316c57a52f11a19907ab03e
GIT binary patch (literal 3616; base85-encoded compiled bytecode omitted)

diff --git a/classifiers/object_classifiers/eval_obj_classifier.py b/classifiers/object_classifiers/eval_obj_classifier.py
index 6c78aa4..e9320c4 100644
--- a/classifiers/object_classifiers/eval_obj_classifier.py
+++ b/classifiers/object_classifiers/eval_obj_classifier.py
@@ -5,7 +5,7 @@ import matplotlib.image as mpimg
 import numpy as np
 from scipy import misc
 import tensorflow as tf
-import object_classifiers.obj_data_io_helper as shape_data_loader
+import obj_data_io_helper as shape_data_loader
 import tf_graph_creation_helper as graph_creator
 import plot_helper as plotter
 
@@ -25,32 +25,42 @@ def eval(eval_params):
     if not os.path.exists(html_dir):
         os.mkdir(html_dir)
     html_writer = shape_data_loader.html_obj_table_writer(os.path.join(html_dir, 'index.html'))
-    col_dict = {0: 'Grount Truth',
-                1: 'Prediction',
-                2: 'Image'}
+    col_dict = {
+        0: 'Ground Truth',
+        1: 'Prediction',
+        2: 'Image'}
     html_writer.add_element(col_dict)
-    shape_dict = {0: 'blank',
-                  1: 'rectangle',
-                  2: 'triangle',
-                  3: 'circle'}
+    shape_dict = {
+        0: 'blank',
+        1: 'rectangle',
+        2: 'triangle',
+        3: 'circle'}
+
     batch_size = 100
     correct = 0
-    for i in range(1):
+    for i in range(50):
         test_batch = shape_data_loader.obj_mini_batch_loader(test_json_filename, image_dir, mean_image, 10000 + i * batch_size, batch_size, 75, 75)
         feed_dict_test = {x: test_batch[0], y: test_batch[1], keep_prob: 1.0}
         result = sess.run([accuracy, y_pred], feed_dict=feed_dict_test)
         correct = correct + result[0] * batch_size
         print correct
+
         for row in range(batch_size * 9):
             gt_id = np.argmax(test_batch[1][row, :])
             pred_id = np.argmax(result[1][row, :])
             if not gt_id == pred_id:
                 img_filename = os.path.join(html_dir, '{}_{}.png'.format(i, row))
-                misc.imsave(img_filename, test_batch[0][row, :, :, :])
-                col_dict = {0: shape_dict[gt_id],
-                 1: shape_dict[pred_id],
-                 2: html_writer.image_tag('{}_{}.png'.format(i, row), 25, 25)}
+                misc.imsave(img_filename, test_batch[0][row, :, :, :] + mean_image)
+                col_dict = {
+                    0: shape_dict[gt_id],
+                    1: shape_dict[pred_id],
+                    2: html_writer.image_tag('{}_{}.png'.format(i, row), 25, 25)}
                 html_writer.add_element(col_dict)
 
     html_writer.close_file()
     print 'Test Accuracy: {}'.format(correct / 5000)
+
+    sess.close()
+    tf.reset_default_graph()
diff --git a/classifiers/object_classifiers/eval_obj_classifier.pyc b/classifiers/object_classifiers/eval_obj_classifier.pyc
index c4e6f97b11143b781aef4f2a8c5d37ac1afe9cfa..48898ca0a27747f6356b9311b4ebe8a2e4f3f745 100644
GIT binary patch (delta 426 / delta 437; base85-encoded compiled bytecode omitted)

diff --git a/classifiers/object_classifiers/obj_data_io_helper.py b/classifiers/object_classifiers/obj_data_io_helper.py
index 5a68329..1e9b435 100644
--- a/classifiers/object_classifiers/obj_data_io_helper.py
+++ b/classifiers/object_classifiers/obj_data_io_helper.py
@@ -1,4 +1,3 @@
-#Embedded file name: /home/tanmay/Code/GenVQA/GenVQA/classifiers/object_classifiers/obj_data_io_helper.py
 import json
 import sys
 import os
@@ -9,19 +8,20 @@ import tensorflow as tf
 from scipy import misc
 
 def obj_mini_batch_loader(json_filename, image_dir, mean_image, start_index, batch_size, img_height = 100, img_width = 100, channels = 3):
+
     with open(json_filename, 'r') as json_file:
         json_data = json.load(json_file)
-    obj_images = np.empty(shape=[9 * batch_size,
-     img_height / 3,
-     img_width / 3,
-     channels])
+
+    obj_images = np.empty(shape=[9 * batch_size, img_height / 3, img_width / 3, channels])
     obj_labels = np.zeros(shape=[9 * batch_size, 4])
+
     for i in range(start_index, start_index + batch_size):
         image_name = os.path.join(image_dir, str(i) + '.jpg')
         image = misc.imresize(mpimg.imread(image_name), (img_height, img_width), interp='nearest')
         crop_shape = np.array([image.shape[0], image.shape[1]]) / 3
         selected_anno = [ q for q in json_data if q['image_id'] == i ]
         grid_config = selected_anno[0]['config']
+
         counter = 0
         for grid_row in range(0, 3):
             for grid_col in range(0, 3):
@@ -31,7 +31,7 @@ def obj_mini_batch_loader(json_filename, image_dir, mean_image, start_index, bat
                 if np.ndim(mean_image) == 0:
                     obj_images[9 * (i - start_index) + counter, :, :, :] = cropped_image / 254.0
                 else:
-                    obj_images[9 * (i - start_index) + counter, :, :, :] = (cropped_image - mean_image) / 254
+                    obj_images[9 * (i - start_index) + counter, :, :, :] = (cropped_image / 254.0) - mean_image
                 obj_labels[9 * (i - start_index) + counter, grid_config[6 * grid_row + 2 * grid_col]] = 1
                 counter = counter + 1
 
@@ -51,7 +51,6 @@ def mean_image(json_filename, image_dir, num_images, batch_size, img_height = 10
         mean_image = mean_image + mean_image_batch(json_filename, image_dir, 1 + i * batch_size, batch_size, img_height, img_width, channels)
 
     mean_image = mean_image / max_iter
-    tmp_mean_image = mean_image * 254
     return mean_image
 
 
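The normalization change above makes the transform trivially invertible, which
is what the '+ mean_image' added in eval_obj_classifier.py relies on when
saving misclassified crops. A sketch (illustrative helper names):

    import numpy as np

    def normalize(crop, mean_image):
        # Scale to roughly [0, 1], then center on the dataset mean
        return (crop / 254.0) - mean_image

    def denormalize(img, mean_image):
        # Exact inverse, up to the 254.0 scale that imsave's rescaling absorbs
        return img + mean_image

    crop = np.random.randint(0, 255, (25, 25, 3)).astype(np.float64)
    mean = np.full((25, 25, 3), 0.5)
    assert np.allclose(denormalize(normalize(crop, mean), mean), crop / 254.0)
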
diff --git a/classifiers/object_classifiers/obj_data_io_helper.pyc b/classifiers/object_classifiers/obj_data_io_helper.pyc
index 46ba9ab54418e0b517e20fe9253b2dbea0b4aa99..9bdb6c7015cc3a0da7906325045969549a720070 100644
GIT binary patch (delta 443 / delta 482; base85-encoded compiled bytecode omitted)

diff --git a/classifiers/object_classifiers/train_obj_classifier.py b/classifiers/object_classifiers/train_obj_classifier.py
index 292e1ca..45d5830 100644
--- a/classifiers/object_classifiers/train_obj_classifier.py
+++ b/classifiers/object_classifiers/train_obj_classifier.py
@@ -21,21 +21,26 @@ def train(train_params):
     if not os.path.exists(outdir):
         os.mkdir(outdir)
 
+    # Training Data
     img_width = 75
     img_height = 75
     train_json_filename = train_params['train_json']
     image_dir = train_params['image_dir']
     mean_image = shape_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width)
     np.save(os.path.join(outdir, 'mean_image.npy'), mean_image)
+
+    # Val Data
     val_batch = shape_data_loader.obj_mini_batch_loader(train_json_filename, image_dir, mean_image, 9501, 499, img_height, img_width)
     feed_dict_val = {x: val_batch[0], y: val_batch[1], keep_prob: 1.0}
     
+    # Session Saver
     saver = tf.train.Saver()
     
+    # Start Training
     sess.run(tf.initialize_all_variables())
     batch_size = 10
     max_epoch = 2
-    max_iter = 1
+    max_iter = 950
     val_acc_array_iter = np.empty([max_iter * max_epoch])
     val_acc_array_epoch = np.zeros([max_epoch])
     train_acc_array_epoch = np.zeros([max_epoch])
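
With these settings each epoch consumes batch_size * max_iter = 10 * 950 =
9500 training images (ids 1 through 9500, since batches start at 1 + i *
batch_size), which is consistent with the validation batch above holding out
ids 9501-9999.
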
diff --git a/classifiers/object_classifiers/train_obj_classifier.pyc b/classifiers/object_classifiers/train_obj_classifier.pyc
index d803fd2f3ff49fa7edbd7c3cee464cb231ef703c..eac87e3278c2ff50ee06780b0e6328693ac40b63 100644
GIT binary patch (delta 273 / delta 269; base85-encoded compiled bytecode omitted)

diff --git a/classifiers/train_classifiers.py b/classifiers/train_classifiers.py
index 6367807..d2edabb 100644
--- a/classifiers/train_classifiers.py
+++ b/classifiers/train_classifiers.py
@@ -6,35 +6,59 @@ import numpy as np
 import tensorflow as tf
 import object_classifiers.train_obj_classifier as obj_trainer
 import object_classifiers.eval_obj_classifier as obj_evaluator
+import attribute_classifiers.train_atr_classifier as atr_trainer
+import attribute_classifiers.eval_atr_classifier as atr_evaluator
 
 workflow = {
     'train_obj': False,
     'eval_obj': False,
+    'train_atr': True,
+    'eval_atr': True,
 }
 
-train_params = {
+obj_classifier_train_params = {
     'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier',
     'adam_lr': 0.001,
     'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
     'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
 }
 
-eval_params = {
+obj_classifier_eval_params = {
     'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier',
     'model_name': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier/obj_classifier',
     'global_step': 1,
     'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
     'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
     'html_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier/html_dir',
-    'create_graph': False,
+}
+
+atr_classifier_train_params = {
+    'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier',
+    'adam_lr': 0.001,
+    'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
+    'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
+    'obj_model_name': obj_classifier_eval_params['model_name'],
+    'obj_global_step': 1,
+}
+
+atr_classifier_eval_params = {
+    'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier',
+    'model_name': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/atr_classifier',
+    'global_step': 1,
+    'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
+    'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
+    'html_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/html_dir',
 }
 
 if __name__=='__main__':
     if workflow['train_obj']:
-        obj_trainer.train(train_params)
+        obj_trainer.train(obj_classifier_train_params)
  
-    obj_feat = tf.get_collection('obj_feat', scope='obj/conv2')
-    print(obj_feat)
     if workflow['eval_obj']:
-        obj_evaluator.eval(eval_params)
+        obj_evaluator.eval(obj_classifier_eval_params)
+
+    if workflow['train_atr']:
+        atr_trainer.train(atr_classifier_train_params)
 
+    if workflow['eval_atr']:
+        atr_evaluator.eval(atr_classifier_eval_params)
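
With the workflow flags above, running 'python classifiers/train_classifiers.py'
from the repository root trains and then evaluates only the attribute
classifier; train_atr_classifier.py will try to restore the object checkpoint
at obj_model_name + '-' + str(obj_global_step) (here obj_classifier-1), so the
object classifier must have been trained first.
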
-- 
GitLab