From 7e913c61163eb4e3ddce7c00c78373a94847366b Mon Sep 17 00:00:00 2001
From: tgupta6 <tgupta6@illinois.edu>
Date: Wed, 30 Mar 2016 15:09:10 -0500
Subject: [PATCH] Add shuffled batch creation and HTML answer visualization; fix loss normalization

---
 .gitignore                                    |   1 +
 .../answer_classifier/ans_data_io_helper.py   | 152 ++++++++++--------
 .../answer_classifier/ans_data_io_helper.pyc  | Bin 6022 -> 8092 bytes
 .../answer_classifier/eval_ans_classifier.py  | 127 ++++++++++++---
 .../answer_classifier/train_ans_classifier.py |  69 ++++----
 .../train_ans_classifier.pyc                  | Bin 8282 -> 8282 bytes
 .../#obj_data_io_helper.py#                   |  84 ++++++++++
 classifiers/region_ranker/perfect_ranker.py   |   2 -
 classifiers/region_ranker/perfect_ranker.pyc  | Bin 3307 -> 3307 bytes
 classifiers/tf_graph_creation_helper.py       |   3 +-
 classifiers/tf_graph_creation_helper.pyc      | Bin 11829 -> 11771 bytes
 classifiers/train_classifiers.py              |   4 +-
 object_classifiers/#eval_obj_classifier.py#   |  45 ++++++
 13 files changed, 364 insertions(+), 123 deletions(-)
 create mode 100644 classifiers/object_classifiers/#obj_data_io_helper.py#
 create mode 100644 object_classifiers/#eval_obj_classifier.py#

diff --git a/.gitignore b/.gitignore
index 6fbbab7..f2d66e7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 *~
+shapes_dataset/images_old
 shapes_dataset/images
 shapes_dataset/*.json
\ No newline at end of file
diff --git a/classifiers/answer_classifier/ans_data_io_helper.py b/classifiers/answer_classifier/ans_data_io_helper.py
index 50317d1..a5f0c7c 100644
--- a/classifiers/answer_classifier/ans_data_io_helper.py
+++ b/classifiers/answer_classifier/ans_data_io_helper.py
@@ -2,6 +2,7 @@ import json
 import sys
 import os
 import time
+import random
 import matplotlib.pyplot as plt
 import matplotlib.image as mpimg
 import numpy as np
@@ -99,77 +100,100 @@ def save_regions(image_dir, out_dir, qa_dict, region_anno_dict, start_id,
         
             image_done[image_id-1] = True
 
+class batch_creator():
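+    """Serves mini-batches by mapping batch positions onto (optionally shuffled) QA ids."""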
+    def __init__(self, start_id, end_id):
+        self.start_id = start_id
+        self.end_id = end_id
+        self.id_list = range(start_id, end_id+1)
 
-def ans_mini_batch_loader(qa_dict, region_anno_dict, ans_dict, vocab, 
-                          image_dir, mean_image, start_index, batch_size, 
-                          img_height=100, img_width=100, channels = 3):
-    
-    ans_labels = np.zeros(shape=[batch_size, len(ans_dict)])
-    for i in xrange(start_index, start_index + batch_size):
-        answer = qa_dict[i].answer
-        ans_labels[i-start_index, ans_dict[answer]] = 1
-        
-    # number of regions in the batch
-    count = batch_size*num_proposals;
-    region_shape = np.array([img_height/3, img_width/3], np.int32)
-    region_images = np.zeros(shape=[count, region_shape[0], 
-                                    region_shape[1], channels])
-    region_score = np.zeros(shape=[1,count])
-    partition = np.zeros(shape=[count])
-    question_encodings = np.zeros(shape=[count, len(vocab)])
+    def shuffle_ids(self):
+        random.shuffle(self.id_list)
 
-    for i in xrange(start_index, start_index + batch_size):
-        
-        image_id = qa_dict[i].image_id
-        question = qa_dict[i].question
-        answer = qa_dict[i].answer
-        gt_regions_for_image = region_anno_dict[image_id]
-        start1 = time.time()
-        regions = region_proposer.rank_regions(None, question,
-                                               region_coords, region_coords_,
-                                               gt_regions_for_image,
-                                               False)
-
-        end1 = time.time()
-#        print('Ranking Region: ' + str(end1-start1))
-        question_encoding_tmp = np.zeros(shape=[1, len(vocab)])
-        
-        for word in question[0:-1].split():
+    def qa_index(self, start_index, batch_size):
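+        # Translate a global start_index into a slice of the (possibly shuffled) id list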
+        return self.id_list[start_index - self.start_id 
+                            : start_index - self.start_id + batch_size]
+
+    def ans_mini_batch_loader(self, qa_dict, region_anno_dict, ans_dict, vocab, 
+                              image_dir, mean_image, start_index, batch_size, 
+                              img_height=100, img_width=100, channels = 3):
+
+        q_ids = self.qa_index(start_index, batch_size)
+
+        ans_labels = np.zeros(shape=[batch_size, len(ans_dict)])
+        for i in xrange(batch_size):
+            q_id = q_ids[i]
+            answer = qa_dict[q_id].answer
+            ans_labels[i, ans_dict[answer]] = 1
+            
+        # number of regions in the batch
+        count = batch_size*num_proposals
+        region_shape = np.array([img_height/3, img_width/3], np.int32)
+        region_images = np.zeros(shape=[count, region_shape[0], 
+                                        region_shape[1], channels])
+        region_score = np.zeros(shape=[1,count])
+        partition = np.zeros(shape=[count])
+        question_encodings = np.zeros(shape=[count, len(vocab)])
             
+        for i in xrange(batch_size):
+            q_id = q_ids[i]
+            image_id = qa_dict[q_id].image_id
+            question = qa_dict[q_id].question
+            answer = qa_dict[q_id].answer
+            gt_regions_for_image = region_anno_dict[image_id]
+            regions = region_proposer.rank_regions(None, question,
+                                                   region_coords, 
+                                                   region_coords_,
+                                                   gt_regions_for_image,
+                                                   False)
+            
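+            # Bag-of-words encoding of the question; out-of-vocabulary words map to 'unk'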
+            question_encoding_tmp = np.zeros(shape=[1, len(vocab)])
+            for word in question[0:-1].split():
                 if word not in vocab:
                     word = 'unk'
                 question_encoding_tmp[0, vocab[word]] += 1 
-        question_len = np.sum(question_encoding_tmp)
-        # print(question[0:-1].split())
-        # print(question_len)
-        # print(question_encoding_tmp)
-        # print(vocab)
-        assert (not question_len==0)
-
-        question_encoding_tmp /= question_len
+                    
+            question_len = np.sum(question_encoding_tmp)
+            assert question_len != 0
+            question_encoding_tmp /= question_len
         
-        for j in xrange(num_proposals):
-            counter = j + (i-start_index)*num_proposals
-            
-            proposal = regions[j]
-
-            start2 = time.time()
-            resized_region = mpimg.imread(os.path.join(image_dir,
-                                                       '{}_{}.png'
-                                                       .format(image_id,j)))
-            end2 = time.time()
-#            print('Reading Region: ' + str(end2-start2))
-            region_images[counter,:,:,:] = (resized_region / 254.0) \
-                                           - mean_image
-            region_score[0,counter] = proposal.score
-            partition[counter] = i-start_index
-            
-            question_encodings[counter,:] = question_encoding_tmp
-
-        score_start_id = (i-start_index)*num_proposals
-        region_score[0, score_start_id:score_start_id+num_proposals] /= \
-            np.sum(region_score[0,score_start_id:score_start_id+num_proposals])
-    return region_images, ans_labels, question_encodings, region_score, partition
+            for j in xrange(num_proposals):
+                counter = j + i*num_proposals
+                proposal = regions[j]
+                resized_region = mpimg.imread(os.path.join(image_dir, 
+                                             '{}_{}.png'.format(image_id,j)))
+                region_images[counter,:,:,:] = (resized_region / 254.0) \
+                                               - mean_image
+                region_score[0,counter] = proposal.score
+                partition[counter] = i
+                
+                question_encodings[counter,:] = question_encoding_tmp
+                
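+            # Normalize region scores within this question's block of proposals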
+            score_start_id = i*num_proposals
+            region_score[0, score_start_id:score_start_id+num_proposals] /=\
+                np.sum(region_score[0,score_start_id
+                                        : score_start_id+num_proposals])
+        return region_images, ans_labels, question_encodings, \
+            region_score, partition
+
+
+class html_ans_table_writer():
+    def __init__(self, filename):
+        self.filename = filename
+        self.html_file = open(self.filename, 'w')
+        self.html_file.write("""<!DOCTYPE html>\n<html>\n<body>\n<table border="1" style="width:100%"> \n""")
+    
+    def add_element(self, col_dict):
+        self.html_file.write('    <tr>\n')
+        for key in range(len(col_dict)):
+            self.html_file.write("""    <td>{}</td>\n""".format(col_dict[key]))
+        self.html_file.write('    </tr>\n')
+
+    def image_tag(self, image_path, height, width):
+        return """<img src="{}" alt="IMAGE NOT FOUND!" height={} width={}>""".format(image_path,height,width)
+        
+    def close_file(self):
+        self.html_file.write('</table>\n</body>\n</html>')
+        self.html_file.close()
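+
+# A minimal usage sketch for html_ans_table_writer; the filename and cell
+# values here are hypothetical:
+#   writer = html_ans_table_writer('/tmp/ans_table.html')
+#   writer.add_element({0: 'Question', 1: 'Answer'})  # keys must be 0..len(col_dict)-1
+#   writer.close_file()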
 
     
 if __name__=='__main__':
diff --git a/classifiers/answer_classifier/ans_data_io_helper.pyc b/classifiers/answer_classifier/ans_data_io_helper.pyc
index a82756a940860691e4e9f2f7fbb6be98c2d944f8..31700abb56a4f4321672b3149d430c2bb7291e78 100644
Binary files a/classifiers/answer_classifier/ans_data_io_helper.pyc and b/classifiers/answer_classifier/ans_data_io_helper.pyc differ

diff --git a/classifiers/answer_classifier/eval_ans_classifier.py b/classifiers/answer_classifier/eval_ans_classifier.py
index 35eb657..d6b7eb3 100644
--- a/classifiers/answer_classifier/eval_ans_classifier.py
+++ b/classifiers/answer_classifier/eval_ans_classifier.py
@@ -12,35 +12,33 @@ import plot_helper as plotter
 import ans_data_io_helper as ans_io_helper
 import region_ranker.perfect_ranker as region_proposer 
 import train_ans_classifier as ans_trainer
+from PIL import Image, ImageDraw
 
 def get_pred(y, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
              image_dir, mean_image, start_index, val_set_size, batch_size,
-             placeholders, img_height=100, img_width=100):
+             placeholders, img_height, img_width, batch_creator):
 
     inv_ans_vocab = {v: k for k, v in ans_vocab.items()}
     pred_list = []
     correct = 0
     max_iter = int(math.ceil(val_set_size*1.0/batch_size))
-#    print ([val_set_size, batch_size])
-#    print('max_iter: ' + str(max_iter))
     batch_size_tmp = batch_size
     for i in xrange(max_iter):
         if i==(max_iter-1):
             batch_size_tmp = val_set_size - i*batch_size
+
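+        # The last batch may be smaller; it is padded back up to batch_size below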
         print('Iter: ' + str(i+1) + '/' + str(max_iter))
-#        print batch_size_tmp
+
         region_images, ans_labels, questions, \
-        region_score, partition= \
-            ans_io_helper.ans_mini_batch_loader(qa_anno_dict, 
-                                                region_anno_dict, 
-                                                ans_vocab, vocab, 
-                                                image_dir, mean_image, 
-                                                start_index+i*batch_size, 
-                                                batch_size_tmp, 
-                                                img_height, img_width, 3)
+        region_score, partition = batch_creator \
+            .ans_mini_batch_loader(qa_anno_dict, 
+                                   region_anno_dict, 
+                                   ans_vocab, vocab, 
+                                   image_dir, mean_image, 
+                                   start_index+i*batch_size, 
+                                   batch_size_tmp, 
+                                   img_height, img_width, 3)
             
-        # print [start_index+i*batch_size, 
-        #        start_index+i*batch_size + batch_size_tmp -1]
         if i==max_iter-1:
                                     
             residual_batch_size = batch_size - batch_size_tmp
@@ -63,10 +61,6 @@ def get_pred(y, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
                                         axis=0)
             region_score = np.concatenate((region_score, residual_region_score),
                                           axis=1)
-            # print region_images.shape
-            # print questions.shape
-            # print ans_labels.shape
-            # print region_score.shape
 
         feed_dict = {
             placeholders[0] : region_images, 
@@ -82,8 +76,6 @@ def get_pred(y, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
                 'question_id' : start_index+i*batch_size+j,
                 'answer' : inv_ans_vocab[ans_ids[j]]
             }]
-            # print qa_anno_dict[start_index+i*batch_size+j].question
-            # print inv_ans_vocab[ans_ids[j]]
 
     return pred_list
 
@@ -143,10 +135,15 @@ def eval(eval_params):
 
     placeholders = [image_regions, questions, keep_prob, y, region_score]
 
+    # Batch creator
+    test_batch_creator = ans_io_helper.batch_creator(test_start_id,
+                                                     test_start_id 
+                                                     + test_set_size - 1)
     # Get predictions
-    pred_dict =get_pred(y_avg, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
+    pred_dict = get_pred(y_avg, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
                         image_regions_dir, mean_image, test_start_id, 
-                        test_set_size, batch_size, placeholders, 75, 75)
+                        test_set_size, batch_size, placeholders, 75, 75,
+                        test_batch_creator)
 
     json_filename = os.path.join(outdir, 'predicted_ans_' + \
                                  eval_params['mode'] + '.json')
@@ -154,12 +151,89 @@ def eval(eval_params):
         json.dump(pred_dict, json_file)
 
     
+def create_html_file(outdir, test_anno_filename, regions_anno_filename,
+                     pred_json_filename, image_dir):
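+    # Writes correct_ans.html and incorrect_ans.html, drawing scored region boxes on each image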
+    qa_dict = ans_io_helper.parse_qa_anno(test_anno_filename)
+    region_anno_dict = region_proposer.parse_region_anno(regions_anno_filename)
+    ans_vocab, inv_ans_vocab = ans_io_helper.create_ans_dict()
+
+    with open(pred_json_filename,'r') as json_file:
+        raw_data = json.load(json_file)
+    
+    # Create directory for storing images with region boxes
+    images_bbox_dir = os.path.join(outdir, 'images_bbox')
+    if not os.path.exists(images_bbox_dir):
+        os.mkdir(images_bbox_dir)
+    
+    col_dict = {
+        0 : 'Question_Id',
+        1 : 'Question',
+        2 : 'Answer (GT)',
+        3 : 'Answer (Pred)',
+        4 : 'Image',
+    }
+    html_correct_filename = os.path.join(outdir, 'correct_ans.html')
+    html_writer_correct = ans_io_helper \
+        .html_ans_table_writer(html_correct_filename)
+    html_writer_correct.add_element(col_dict)
+
+    html_incorrect_filename = os.path.join(outdir, 'incorrect_ans.html')
+    html_writer_incorrect = ans_io_helper \
+        .html_ans_table_writer(html_incorrect_filename)
+    html_writer_incorrect.add_element(col_dict)
+
+    region_coords, region_coords_ = region_proposer.get_region_coords(300,300)
+
+    for entry in raw_data:
+        q_id = entry['question_id']
+        pred_ans = entry['answer']
+        gt_ans = qa_dict[q_id].answer
+        question = qa_dict[q_id].question
+        img_id = qa_dict[q_id].image_id
+        image_filename = os.path.join(image_dir, str(img_id) + '.jpg')
+        image = Image.open(image_filename)
+        
+        regions = region_proposer.rank_regions(image, question, region_coords, 
+                                               region_coords_, 
+                                               region_anno_dict[img_id],
+                                               crop=False)
+        dr = ImageDraw.Draw(image)
+        for i in xrange(ans_io_helper.num_proposals):
+            if not regions[i].score==0:
+                coord = regions[i].coord
+                x1 = coord[0]
+                y1 = coord[1]
+                x2 = coord[2]
+                y2 = coord[3]
+                dr.rectangle([(x1,y1),(x2,y2)], outline="red")
+        
+        image_bbox_filename = os.path.join(images_bbox_dir,str(q_id) + '.jpg')
+        image.save(image_bbox_filename)
+        image_bbox_filename_rel = 'images_bbox/' + str(q_id) + '.jpg' 
+        col_dict = {
+            0 : q_id,
+            1 : question,
+            2 : gt_ans,
+            3 : pred_ans,
+            4 : html_writer_correct.image_tag(image_bbox_filename_rel,50,50)
+        }
+        if pred_ans==gt_ans:
+            html_writer_correct.add_element(col_dict)
+        else:
+            html_writer_incorrect.add_element(col_dict)
+
+    html_writer_correct.close_file()
+    html_writer_incorrect.close_file()
+    
 
 if __name__=='__main__':
     ans_classifier_eval_params = {
         'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
         'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
         'regions_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/regions_anno.json',
+        'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
         'image_regions_dir': '/mnt/ramdisk/image_regions',
         'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier',
         'model': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier/ans_classifier_q_obj_atr-9',
@@ -169,4 +243,11 @@ if __name__=='__main__':
         'test_set_size': 160725-111352+1,
     }
 
-    eval(ans_classifier_eval_params)
+#    eval(ans_classifier_eval_params)
+    outdir = ans_classifier_eval_params['outdir']
+    test_anno_filename = ans_classifier_eval_params['test_json']
+    regions_anno_filename = ans_classifier_eval_params['regions_json']
+    pred_json_filename = os.path.join(outdir, 'predicted_ans_q.json')
+    image_dir = ans_classifier_eval_params['image_dir']
+    create_html_file(outdir, test_anno_filename, regions_anno_filename,
+                     pred_json_filename, image_dir)
diff --git a/classifiers/answer_classifier/train_ans_classifier.py b/classifiers/answer_classifier/train_ans_classifier.py
index 524d294..764d4bc 100644
--- a/classifiers/answer_classifier/train_ans_classifier.py
+++ b/classifiers/answer_classifier/train_ans_classifier.py
@@ -22,20 +22,17 @@ val_set_size_small = 100
 
 def evaluate(accuracy, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
              image_dir, mean_image, start_index, val_set_size, batch_size,
-             placeholders, img_height=100, img_width=100):
+             placeholders, img_height, img_width, batch_creator):
     
     correct = 0
     max_iter = int(math.floor(val_set_size/batch_size))
     for i in xrange(max_iter):
         region_images, ans_labels, questions, \
-        region_score, partition= \
-            ans_io_helper.ans_mini_batch_loader(qa_anno_dict, 
-                                                region_anno_dict, 
-                                                ans_vocab, vocab, 
-                                                image_dir, mean_image, 
-                                                start_index+i*batch_size, 
-                                                batch_size, 
-                                                img_height, img_width, 3)
+        region_score, partition = batch_creator \
+            .ans_mini_batch_loader(qa_anno_dict, region_anno_dict, 
+                                   ans_vocab, vocab, image_dir, mean_image, 
+                                   start_index+i*batch_size, batch_size, 
+                                   img_height, img_width, 3)
             
         feed_dict = {
             placeholders[0] : region_images, 
@@ -224,32 +221,43 @@ def train(train_params):
 
     placeholders = [image_regions, questions, keep_prob, y, region_score]
 
+
+    # Start Training
+    max_epoch = train_params['max_epoch']
+    max_iter = 5000
+    val_acc_array_epoch = np.zeros([max_epoch])
+    train_acc_array_epoch = np.zeros([max_epoch])
+
+    # Batch creators
+    train_batch_creator = ans_io_helper.batch_creator(1, max_iter*batch_size)
+    val_batch_creator = ans_io_helper.batch_creator(val_start_id, val_start_id 
+                                                    + val_set_size - 1)
+    val_small_batch_creator = ans_io_helper.batch_creator(val_start_id, 
+                                                          val_start_id + 
+                                                          val_set_size_small-1)
+
+    # Check accuracy of restored model
     if train_params['fine_tune']==True:
         restored_accuracy = evaluate(accuracy, qa_anno_dict, 
                                      region_anno_dict, ans_vocab, 
                                      vocab, image_regions_dir, 
                                      mean_image, val_start_id, 
                                      val_set_size, batch_size,
-                                     placeholders, 75, 75)
+                                     placeholders, 75, 75,
+                                     val_batch_creator)
         print('Accuracy of restored model: ' + str(restored_accuracy))
-
-    # Start Training
-    max_epoch = train_params['max_epoch']
-    max_iter = 5000
-    val_acc_array_epoch = np.zeros([max_epoch])
-    train_acc_array_epoch = np.zeros([max_epoch])
+    
     for epoch in range(start_epoch, max_epoch):
-        iter_ids = range(max_iter)
-        random.shuffle(iter_ids)
-        for i in iter_ids: #range(max_iter):
+        train_batch_creator.shuffle_ids()
+        for i in range(max_iter):
         
             train_region_images, train_ans_labels, train_questions, \
-            train_region_score, train_partition= \
-            ans_io_helper.ans_mini_batch_loader(qa_anno_dict, region_anno_dict, 
-                                                ans_vocab, vocab, 
-                                                image_regions_dir, mean_image, 
-                                                1+i*batch_size, batch_size, 
-                                                75, 75, 3)
+            train_region_score, train_partition = train_batch_creator \
+                .ans_mini_batch_loader(qa_anno_dict, region_anno_dict, 
+                                       ans_vocab, vocab, 
+                                       image_regions_dir, mean_image, 
+                                       1+i*batch_size, batch_size, 
+                                       75, 75, 3)
                         
             feed_dict_train = {
                 image_regions : train_region_images, 
@@ -258,7 +266,7 @@ def train(train_params):
                 y: train_ans_labels,        
                 region_score: train_region_score,
             }
-
+            
             if pretrained_vars_low_lr:
                 _, _, current_train_batch_acc, y_pred_eval, loss_eval = \
                     sess.run([train_step_low_lr, train_step_high_lr,
@@ -280,10 +288,10 @@ def train(train_params):
                                         region_anno_dict, ans_vocab, vocab,
                                         image_regions_dir, mean_image, 
                                         val_start_id, val_set_size_small,
-                                        batch_size, placeholders, 75, 75)
+                                        batch_size, placeholders, 75, 75,
+                                        val_small_batch_creator)
                 
-                print('Iter: ' + str(i+1) + ' Val Sm Acc: ' + str(val_accuracy) 
-                      + ' Loss: ' + str(loss_eval))
+                print('Iter: ' + str(i+1) + ' Val Sm Acc: ' + str(val_accuracy))
 
         train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] / max_iter
         val_acc_array_epoch[epoch] = evaluate(accuracy, qa_anno_dict, 
@@ -291,7 +299,8 @@ def train(train_params):
                                               vocab, image_regions_dir, 
                                               mean_image, val_start_id, 
                                               val_set_size, batch_size,
-                                              placeholders, 75, 75)
+                                              placeholders, 75, 75,
+                                              val_batch_creator)
 
         print('Val Acc: ' + str(val_acc_array_epoch[epoch]) + 
               ' Train Acc: ' + str(train_acc_array_epoch[epoch]))
diff --git a/classifiers/answer_classifier/train_ans_classifier.pyc b/classifiers/answer_classifier/train_ans_classifier.pyc
index 7a476b58d08ede1221bca8d1d9e83afa1f34f095..ee2dbe19a14f6d26908de083a00f417fc54d44d6 100644
Binary files a/classifiers/answer_classifier/train_ans_classifier.pyc and b/classifiers/answer_classifier/train_ans_classifier.pyc differ

diff --git a/classifiers/object_classifiers/#obj_data_io_helper.py# b/classifiers/object_classifiers/#obj_data_io_helper.py#
new file mode 100644
index 0000000..caab72a
--- /dev/null
+++ b/classifiers/object_classifiers/#obj_data_io_helper.py#
@@ -0,0 +1,84 @@
+#Embedded file name: /home/tanmay/Code/GenVQA/GenVQA/classifiers/object_classifiers/obj_data_io_helper.py
+import json
+import sys
+import os
+import matplotlib.pyplot as plt
+import matplotlib.image as mpimg
+import numpy as np
+import tensorflow as tf
+from scipy import misc
+
+def obj_mini_batch_loader(json_filename, image_dir, mean_image, start_index, batch_size, img_height = 100, img_width = 100, channels = 3):
+    with open(json_filename, 'r') as json_file:
+        json_data = json.load(json_file)
+    obj_images = np.empty(shape=[9 * batch_size, img_height / 3,
+                                 img_width / 3, channels])
+    obj_labels = np.zeros(shape=[9 * batch_size, 4])
+    for i in range(start_index, start_index + batch_size):
+        image_name = os.path.join(image_dir, str(i) + '.jpg')
+        image = misc.imresize(mpimg.imread(image_name), (img_height, img_width), interp='nearest')
+        crop_shape = np.array([image.shape[0], image.shape[1]]) / 3
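+        # Each image is split into a 3x3 grid of equal-sized crops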
+        selected_anno = [ q for q in json_data if q['image_id'] == i ]
+        grid_config = selected_anno[0]['config']
+        counter = 0
+        for grid_row in range(0, 3):
+            for grid_col in range(0, 3):
+                start_row = grid_row * crop_shape[0]
+                start_col = grid_col * crop_shape[1]
+                cropped_image = image[start_row:start_row + crop_shape[0], start_col:start_col + crop_shape[1], :]
+                if np.ndim(mean_image) == 0:
+                    obj_images[9 * (i - start_index) + counter, :, :, :] = cropped_image / 254.0
+                else:
+                    obj_images[9 * (i - start_index) + counter, :, :, :] = (cropped_image / 254.0) - mean_image
+                obj_labels[9 * (i - start_index) + counter, grid_config[6 * grid_row + 2 * grid_col]] = 1
+                counter = counter + 1
+
+    return (obj_images, obj_labels)
+
+
+def mean_image_batch(json_filename, image_dir, start_index, batch_size, img_height = 100, img_width = 100, channels = 3):
+    batch = obj_mini_batch_loader(json_filename, image_dir, np.empty([]), start_index, batch_size, img_height, img_width, channels)
+    mean_image = np.mean(batch[0], 0)
+    return mean_image
+
+
+def mean_image(json_filename, image_dir, num_images, batch_size, img_height = 100, img_width = 100, channels = 3):
+    max_iter = np.floor(num_images / batch_size)
+    mean_image = np.zeros([img_height / 3, img_width / 3, channels])
+    for i in range(max_iter.astype(np.int16)):
+        mean_image = mean_image + mean_image_batch(json_filename, image_dir, 1 + i * batch_size, batch_size, img_height, img_width, channels)
+
+    mean_image = mean_image / max_iter
+    return mean_image
+
+
+class html_obj_table_writer:
+
+    def __init__(self, filename):
+        self.filename = filename
+        self.html_file = open(self.filename, 'w')
+        self.html_file.write('<!DOCTYPE html>\n<html>\n<body>\n<table border="1" style="width:100%"> \n')
+
+    def add_element(self, col_dict):
+        self.html_file.write('    <tr>\n')
+        for key in range(len(col_dict)):
+            self.html_file.write('    <td>{}</td>\n'.format(col_dict[key]))
+
+        self.html_file.write('    </tr>\n')
+
+    def image_tag(self, image_path, height, width):
+        return '<img src="{}" alt="IMAGE NOT FOUND!" height={} width={}>'.format(image_path, height, width)
+
+    def close_file(self):
+        self.html_file.write('</table>\n</body>\n</html>')
+        self.html_file.close()
+
+
+if __name__ == '__main__':
+    html_writer = html_obj_table_writer('/home/tanmay/Code/GenVQA/Exp_Results/Shape_Classifier_v_1/trial.html')
+    col_dict = {0: 'sam',
+                1: html_writer.image_tag('something.png', 25, 25)}
+    html_writer.add_element(col_dict)
+    html_writer.close_file()
diff --git a/classifiers/region_ranker/perfect_ranker.py b/classifiers/region_ranker/perfect_ranker.py
index 9d44e14..5ab52a0 100644
--- a/classifiers/region_ranker/perfect_ranker.py
+++ b/classifiers/region_ranker/perfect_ranker.py
@@ -8,7 +8,6 @@ import matplotlib.image as mpimg
 from scipy import misc
 region = namedtuple('region','image score coord')
 
-
 def parse_region_anno(json_filename):
     with open(json_filename,'r') as json_file:
         raw_data = json.load(json_file)
@@ -19,7 +18,6 @@ def parse_region_anno(json_filename):
         
     return region_anno_dict
 
-
 def get_region_coords(img_height, img_width):
     region_coords_ = np.array([[   1,     1,   100,   100],
                              [  101,    1,  200,  100],
diff --git a/classifiers/region_ranker/perfect_ranker.pyc b/classifiers/region_ranker/perfect_ranker.pyc
index 6807d1410576a03a277f3b0cf87f780fe3fe750f..763268d210df514c8542e4b876ade8918f75b1ad 100644
Binary files a/classifiers/region_ranker/perfect_ranker.pyc and b/classifiers/region_ranker/perfect_ranker.pyc differ

diff --git a/classifiers/tf_graph_creation_helper.py b/classifiers/tf_graph_creation_helper.py
index db49175..bab0ffd 100644
--- a/classifiers/tf_graph_creation_helper.py
+++ b/classifiers/tf_graph_creation_helper.py
@@ -282,8 +282,7 @@ def loss(y, y_pred):
     cross_entropy = -tf.reduce_sum(y * tf.log(y_pred_clipped), 
                                    name='cross_entropy')
     batch_size = tf.shape(y)
-    print 'Batch Size:' + str(tf.cast(batch_size[0],tf.float32))
-    return tf.truediv(cross_entropy, tf.cast(20,tf.float32))#batch_size[0],tf.float32))
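+    # Normalize the summed cross entropy by the actual (dynamic) batch size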
+    return tf.truediv(cross_entropy, tf.cast(batch_size[0],tf.float32))
 
 
 if __name__ == '__main__':
diff --git a/classifiers/tf_graph_creation_helper.pyc b/classifiers/tf_graph_creation_helper.pyc
index 02d92a889c316357acd64391630440f3d31982be..367a32fb594471ea8c6b8dc46181d6581eb8ecc7 100644
Binary files a/classifiers/tf_graph_creation_helper.pyc and b/classifiers/tf_graph_creation_helper.pyc differ

diff --git a/classifiers/train_classifiers.py b/classifiers/train_classifiers.py
index 3bf5366..fc3b0bb 100644
--- a/classifiers/train_classifiers.py
+++ b/classifiers/train_classifiers.py
@@ -73,8 +73,8 @@ ans_classifier_train_params = {
     'crop_n_save_regions': False,
     'max_epoch': 10,
     'batch_size': 20,
-    'fine_tune': False,
-    'start_model': 9,
+    'fine_tune': True,
+    'start_model': 3,
 }
 
 if __name__=='__main__':
diff --git a/object_classifiers/#eval_obj_classifier.py# b/object_classifiers/#eval_obj_classifier.py#
new file mode 100644
index 0000000..28fa50e
--- /dev/null
+++ b/object_classifiers/#eval_obj_classifier.py#
@@ -0,0 +1,45 @@
+import sys
+import os
+import matplotlib.pyplot as plt
+import matplotlib.image as mpimg
+import numpy as np
+import tensorflow as tf
+import obj_data_io_helper as shape_data_loader 
+from train_obj_classifier import placeholder_inputs, comp_graph_v_1, evaluation
+
+sess=tf.InteractiveSession()
+
+x, y, keep_prob = placeholder_inputs()
+y_pred = comp_graph_v_1(x, y, keep_prob)
+
+accuracy = evaluation(y, y_pred)
+
+saver = tf.train.Saver()
+
+saver.restore(sess, '/home/tanmay/Code/GenVQA/Exp_Results/Shape_Classifier_v_1/obj_classifier_9.ckpt')
+
+mean_image = np.load('/home/tanmay/Code/GenVQA/Exp_Results/Shape_Classifier_v_1/mean_image.npy')
+
+# Test Data
+test_json_filename = '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json'
+image_dir = '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images'
+
+# HTML file writer
+html_writer = shape_data_loader.html_obj_table_writer('/home/tanmay/Code/GenVQA/Exp_Results/Shape_Classifier_v_1/trial.html')
+batch_size = 100
+correct = 0
+for i in range(1):  # truncated to a single batch for debugging; the full test sweep is range(50)
+    test_batch = shape_data_loader.obj_mini_batch_loader(test_json_filename, image_dir, mean_image, 10000+i*batch_size, batch_size, 75, 75)
+    feed_dict_test={x: test_batch[0], y: test_batch[1], keep_prob: 1.0}
+    result = sess.run([accuracy, y_pred], feed_dict=feed_dict_test)
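+    # result[0] is the batch accuracy (a fraction); result[1] holds the evaluated predictions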
+    correct = correct + result[0]*batch_size
+    print(correct)
+
+    for row in range(batch_size):
+        col_dict = {
+            0: test_batch[1][row,:],
+            1: result[1][row, :]}
+        html_writer.add_element(col_dict)
+
+html_writer.close_file()
+print('Test Accuracy: {}'.format(correct/5000))  # 5000 = 50 batches * batch_size of 100
-- 
GitLab