Commit 7e913c61 authored by tgupta6

about to merge with main

parent edd34704
Showing 364 additions and 123 deletions
 *~
+shapes_dataset/images_old
 shapes_dataset/images
 shapes_dataset/*.json
\ No newline at end of file
@@ -2,6 +2,7 @@ import json
 import sys
 import os
 import time
+import random
 import matplotlib.pyplot as plt
 import matplotlib.image as mpimg
 import numpy as np
@@ -99,77 +100,100 @@ def save_regions(image_dir, out_dir, qa_dict, region_anno_dict, start_id,
         image_done[image_id-1] = True
 
-def ans_mini_batch_loader(qa_dict, region_anno_dict, ans_dict, vocab,
-                          image_dir, mean_image, start_index, batch_size,
-                          img_height=100, img_width=100, channels = 3):
-    ans_labels = np.zeros(shape=[batch_size, len(ans_dict)])
-    for i in xrange(start_index, start_index + batch_size):
-        answer = qa_dict[i].answer
-        ans_labels[i-start_index, ans_dict[answer]] = 1
-
-    # number of regions in the batch
-    count = batch_size*num_proposals;
-    region_shape = np.array([img_height/3, img_width/3], np.int32)
-    region_images = np.zeros(shape=[count, region_shape[0],
-                                    region_shape[1], channels])
-    region_score = np.zeros(shape=[1,count])
-    partition = np.zeros(shape=[count])
-    question_encodings = np.zeros(shape=[count, len(vocab)])
-
-    for i in xrange(start_index, start_index + batch_size):
-        image_id = qa_dict[i].image_id
-        question = qa_dict[i].question
-        answer = qa_dict[i].answer
-        gt_regions_for_image = region_anno_dict[image_id]
-        start1 = time.time()
-        regions = region_proposer.rank_regions(None, question,
-                                               region_coords, region_coords_,
-                                               gt_regions_for_image,
-                                               False)
-        end1 = time.time()
-        # print('Ranking Region: ' + str(end1-start1))
-        question_encoding_tmp = np.zeros(shape=[1, len(vocab)])
-        for word in question[0:-1].split():
-            if word not in vocab:
-                word = 'unk'
-            question_encoding_tmp[0, vocab[word]] += 1
-        question_len = np.sum(question_encoding_tmp)
-        # print(question[0:-1].split())
-        # print(question_len)
-        # print(question_encoding_tmp)
-        # print(vocab)
-        assert (not question_len==0)
-        question_encoding_tmp /= question_len
-        for j in xrange(num_proposals):
-            counter = j + (i-start_index)*num_proposals
-            proposal = regions[j]
-            start2 = time.time()
-            resized_region = mpimg.imread(os.path.join(image_dir,
-                                                       '{}_{}.png'
-                                                       .format(image_id,j)))
-            end2 = time.time()
-            # print('Reading Region: ' + str(end2-start2))
-            region_images[counter,:,:,:] = (resized_region / 254.0) \
-                                           - mean_image
-            region_score[0,counter] = proposal.score
-            partition[counter] = i-start_index
-
-            question_encodings[counter,:] = question_encoding_tmp
-
-        score_start_id = (i-start_index)*num_proposals
-        region_score[0, score_start_id:score_start_id+num_proposals] /= \
-            np.sum(region_score[0,score_start_id:score_start_id+num_proposals])
-    return region_images, ans_labels, question_encodings, region_score, partition
+class batch_creator():
+    def __init__(self, start_id, end_id):
+        self.start_id = start_id
+        self.end_id = end_id
+        self.id_list = range(start_id, end_id+1)
+
+    def shuffle_ids(self):
+        random.shuffle(self.id_list)
+
+    def qa_index(self, start_index, batch_size):
+        return self.id_list[start_index - self.start_id
+                            : start_index - self.start_id + batch_size]
+
+    def ans_mini_batch_loader(self, qa_dict, region_anno_dict, ans_dict, vocab,
+                              image_dir, mean_image, start_index, batch_size,
+                              img_height=100, img_width=100, channels = 3):
+        q_ids = self.qa_index(start_index, batch_size)
+
+        ans_labels = np.zeros(shape=[batch_size, len(ans_dict)])
+        for i in xrange(batch_size):
+            q_id = q_ids[i]
+            answer = qa_dict[q_id].answer
+            ans_labels[i, ans_dict[answer]] = 1
+
+        # number of regions in the batch
+        count = batch_size*num_proposals;
+        region_shape = np.array([img_height/3, img_width/3], np.int32)
+        region_images = np.zeros(shape=[count, region_shape[0],
+                                        region_shape[1], channels])
+        region_score = np.zeros(shape=[1,count])
+        partition = np.zeros(shape=[count])
+        question_encodings = np.zeros(shape=[count, len(vocab)])
+
+        for i in xrange(batch_size):
+            q_id = q_ids[i]
+            image_id = qa_dict[q_id].image_id
+            question = qa_dict[q_id].question
+            answer = qa_dict[q_id].answer
+            gt_regions_for_image = region_anno_dict[image_id]
+            regions = region_proposer.rank_regions(None, question,
+                                                   region_coords,
+                                                   region_coords_,
+                                                   gt_regions_for_image,
+                                                   False)
+            question_encoding_tmp = np.zeros(shape=[1, len(vocab)])
+            for word in question[0:-1].split():
+                if word not in vocab:
+                    word = 'unk'
+                question_encoding_tmp[0, vocab[word]] += 1
+
+            question_len = np.sum(question_encoding_tmp)
+            assert (not question_len==0)
+            question_encoding_tmp /= question_len
+
+            for j in xrange(num_proposals):
+                counter = j + i*num_proposals
+                proposal = regions[j]
+                resized_region = mpimg.imread(os.path.join(image_dir,
+                    '{}_{}.png'.format(image_id,j)))
+                region_images[counter,:,:,:] = (resized_region / 254.0) \
+                                               - mean_image
+                region_score[0,counter] = proposal.score
+                partition[counter] = i
+
+                question_encodings[counter,:] = question_encoding_tmp
+
+            score_start_id = i*num_proposals
+            region_score[0, score_start_id:score_start_id+num_proposals] /=\
+                np.sum(region_score[0,score_start_id
+                                    : score_start_id+num_proposals])
+
+        return region_images, ans_labels, question_encodings, \
+            region_score, partition
+
+
+class html_ans_table_writer():
+    def __init__(self, filename):
+        self.filename = filename
+        self.html_file = open(self.filename, 'w')
+        self.html_file.write("""<!DOCTYPE html>\n<html>\n<body>\n<table border="1" style="width:100%"> \n""")
+
+    def add_element(self, col_dict):
+        self.html_file.write('    <tr>\n')
+        for key in range(len(col_dict)):
+            self.html_file.write("""    <td>{}</td>\n""".format(col_dict[key]))
+        self.html_file.write('    </tr>\n')
+
+    def image_tag(self, image_path, height, width):
+        return """<img src="{}" alt="IMAGE NOT FOUND!" height={} width={}>""".format(image_path,height,width)
+
+    def close_file(self):
+        self.html_file.write('</table>\n</body>\n</html>')
+        self.html_file.close()
 
 if __name__=='__main__':
 ...
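For orientation, below is a minimal sketch (not repo code) of how the new batch_creator is meant to be driven; it mirrors the call sites in the training script later in this commit. The qa_dict, region_anno_dict, ans_vocab, vocab, image_dir, and mean_image objects are assumed to come from the repo's existing parse/create helpers. Note that shuffle_ids relies on range() returning a list, i.e. Python 2.

# Hypothetical driver for the new batch_creator API (a sketch under the
# assumptions above, not committed code).
import ans_data_io_helper as ans_io_helper

batch_size = 20
max_iter = 5000
creator = ans_io_helper.batch_creator(1, max_iter * batch_size)
for epoch in range(10):
    creator.shuffle_ids()  # permute the question ids once per epoch
    for i in range(max_iter):
        region_images, ans_labels, questions, region_score, partition = \
            creator.ans_mini_batch_loader(qa_dict, region_anno_dict,
                                          ans_vocab, vocab, image_dir,
                                          mean_image, 1 + i * batch_size,
                                          batch_size, 75, 75, 3)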
No preview for this file type
@@ -12,35 +12,33 @@ import plot_helper as plotter
 import ans_data_io_helper as ans_io_helper
 import region_ranker.perfect_ranker as region_proposer
 import train_ans_classifier as ans_trainer
+from PIL import Image, ImageDraw
 
 def get_pred(y, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
              image_dir, mean_image, start_index, val_set_size, batch_size,
-             placeholders, img_height=100, img_width=100):
+             placeholders, img_height, img_width, batch_creator):
     inv_ans_vocab = {v: k for k, v in ans_vocab.items()}
     pred_list = []
     correct = 0
     max_iter = int(math.ceil(val_set_size*1.0/batch_size))
-    # print ([val_set_size, batch_size])
-    # print('max_iter: ' + str(max_iter))
     batch_size_tmp = batch_size
     for i in xrange(max_iter):
         if i==(max_iter-1):
             batch_size_tmp = val_set_size - i*batch_size
         print('Iter: ' + str(i+1) + '/' + str(max_iter))
-        # print batch_size_tmp
         region_images, ans_labels, questions, \
-        region_score, partition= \
-            ans_io_helper.ans_mini_batch_loader(qa_anno_dict,
-                                                region_anno_dict,
-                                                ans_vocab, vocab,
-                                                image_dir, mean_image,
-                                                start_index+i*batch_size,
-                                                batch_size_tmp,
-                                                img_height, img_width, 3)
-        # print [start_index+i*batch_size,
-        #        start_index+i*batch_size + batch_size_tmp -1]
+        region_score, partition = batch_creator \
+            .ans_mini_batch_loader(qa_anno_dict,
+                                   region_anno_dict,
+                                   ans_vocab, vocab,
+                                   image_dir, mean_image,
+                                   start_index+i*batch_size,
+                                   batch_size_tmp,
+                                   img_height, img_width, 3)
 
         if i==max_iter-1:
             residual_batch_size = batch_size - batch_size_tmp
@@ -63,10 +61,6 @@ def get_pred(y, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
                                            axis=0)
         region_score = np.concatenate((region_score, residual_region_score),
                                       axis=1)
-        # print region_images.shape
-        # print questions.shape
-        # print ans_labels.shape
-        # print region_score.shape
 
         feed_dict = {
             placeholders[0] : region_images,
@@ -82,8 +76,6 @@ def get_pred(y, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
                 'question_id' : start_index+i*batch_size+j,
                 'answer' : inv_ans_vocab[ans_ids[j]]
             }]
-            # print qa_anno_dict[start_index+i*batch_size+j].question
-            # print inv_ans_vocab[ans_ids[j]]
 
     return pred_list
@@ -143,10 +135,15 @@ def eval(eval_params):
     placeholders = [image_regions, questions, keep_prob, y, region_score]
 
+    # Batch creator
+    test_batch_creator = ans_io_helper.batch_creator(test_start_id,
+                                                     test_start_id
+                                                     + test_set_size - 1)
+
     # Get predictions
-    pred_dict =get_pred(y_avg, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
-                        image_regions_dir, mean_image, test_start_id,
-                        test_set_size, batch_size, placeholders, 75, 75)
+    pred_dict = get_pred(y_avg, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
+                         image_regions_dir, mean_image, test_start_id,
+                         test_set_size, batch_size, placeholders, 75, 75,
+                         test_batch_creator)
 
     json_filename = os.path.join(outdir, 'predicted_ans_' + \
                                  eval_params['mode'] + '.json')
@@ -154,12 +151,89 @@ def eval(eval_params):
         json.dump(pred_dict, json_file)
 
+def create_html_file(outdir, test_anno_filename, regions_anno_filename,
+                     pred_json_filename, image_dir):
+    qa_dict = ans_io_helper.parse_qa_anno(test_anno_filename)
+    region_anno_dict = region_proposer.parse_region_anno(regions_anno_filename)
+    ans_vocab, inv_ans_vocab = ans_io_helper.create_ans_dict()
+    with open(pred_json_filename,'r') as json_file:
+        raw_data = json.load(json_file)
+
+    # Create directory for storing images with region boxes
+    images_bbox_dir = os.path.join(outdir, 'images_bbox')
+    if not os.path.exists(images_bbox_dir):
+        os.mkdir(images_bbox_dir)
+
+    col_dict = {
+        0 : 'Question_Id',
+        1 : 'Question',
+        2 : 'Answer (GT)',
+        3 : 'Answer (Pred)',
+        4 : 'Image',
+    }
+    html_correct_filename = os.path.join(outdir, 'correct_ans.html')
+    html_writer_correct = ans_io_helper \
+        .html_ans_table_writer(html_correct_filename)
+    html_writer_correct.add_element(col_dict)
+
+    html_incorrect_filename = os.path.join(outdir, 'incorrect_ans.html')
+    html_writer_incorrect = ans_io_helper \
+        .html_ans_table_writer(html_incorrect_filename)
+    html_writer_incorrect.add_element(col_dict)
+
+    region_coords, region_coords_ = region_proposer.get_region_coords(300,300)
+
+    for entry in raw_data:
+        q_id = entry['question_id']
+        pred_ans = entry['answer']
+        gt_ans = qa_dict[q_id].answer
+        question = qa_dict[q_id].question
+        img_id = qa_dict[q_id].image_id
+        image_filename = os.path.join(image_dir, str(img_id) + '.jpg')
+        image = Image.open(image_filename)
+        regions = region_proposer.rank_regions(image, question, region_coords,
+                                               region_coords_,
+                                               region_anno_dict[img_id],
+                                               crop=False)
+        dr = ImageDraw.Draw(image)
+        # print(q_id)
+        # print([regions[key].score for key in regions.keys()])
+        for i in xrange(ans_io_helper.num_proposals):
+            if not regions[i].score==0:
+                coord = regions[i].coord
+                x1 = coord[0]
+                y1 = coord[1]
+                x2 = coord[2]
+                y2 = coord[3]
+                dr.rectangle([(x1,y1),(x2,y2)], outline="red")
+
+        image_bbox_filename = os.path.join(images_bbox_dir,str(q_id) + '.jpg')
+        image.save(image_bbox_filename)
+        image_bbox_filename_rel = 'images_bbox/' + str(q_id) + '.jpg'
+        col_dict = {
+            0 : q_id,
+            1 : question,
+            2 : gt_ans,
+            3 : pred_ans,
+            4 : html_writer_correct.image_tag(image_bbox_filename_rel,50,50)
+        }
+        if pred_ans==gt_ans:
+            html_writer_correct.add_element(col_dict)
+        else:
+            html_writer_incorrect.add_element(col_dict)
+
+    html_writer_correct.close_file()
+    html_writer_incorrect.close_file()
+
 if __name__=='__main__':
     ans_classifier_eval_params = {
         'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
         'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
         'regions_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/regions_anno.json',
+        'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
         'image_regions_dir': '/mnt/ramdisk/image_regions',
         'outdir': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier',
         'model': '/home/tanmay/Code/GenVQA/Exp_Results/Ans_Classifier/ans_classifier_q_obj_atr-9',
@@ -169,4 +243,11 @@ if __name__=='__main__':
         'test_set_size': 160725-111352+1,
     }
 
-    eval(ans_classifier_eval_params)
+    # eval(ans_classifier_eval_params)
+    outdir = ans_classifier_eval_params['outdir']
+    test_anno_filename = ans_classifier_eval_params['test_json']
+    regions_anno_filename = ans_classifier_eval_params['regions_json']
+    pred_json_filename = os.path.join(outdir, 'predicted_ans_q.json')
+    image_dir = ans_classifier_eval_params['image_dir']
+    create_html_file(outdir, test_anno_filename, regions_anno_filename,
+                     pred_json_filename, image_dir)
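As a quick check of the report format produced by create_html_file, here is a minimal standalone use of html_ans_table_writer; the output path is illustrative and the rows mimic the columns used above. Each add_element call emits one <tr> whose cells are the dict values in key order 0..n-1, so the dict keys must be consecutive integers starting at 0.

# Illustrative use of html_ans_table_writer ('/tmp/demo.html' is a
# placeholder path; the question/answer strings are made up).
import ans_data_io_helper as ans_io_helper

writer = ans_io_helper.html_ans_table_writer('/tmp/demo.html')
writer.add_element({0: 'Question_Id', 1: 'Question', 2: 'Answer (GT)',
                    3: 'Answer (Pred)', 4: 'Image'})
writer.add_element({0: 1, 1: 'what is red?', 2: 'circle', 3: 'circle',
                    4: writer.image_tag('images_bbox/1.jpg', 50, 50)})
writer.close_file()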
@@ -22,20 +22,17 @@ val_set_size_small = 100
 
 def evaluate(accuracy, qa_anno_dict, region_anno_dict, ans_vocab, vocab,
              image_dir, mean_image, start_index, val_set_size, batch_size,
-             placeholders, img_height=100, img_width=100):
+             placeholders, img_height, img_width, batch_creator):
     correct = 0
     max_iter = int(math.floor(val_set_size/batch_size))
     for i in xrange(max_iter):
         region_images, ans_labels, questions, \
-        region_score, partition= \
-            ans_io_helper.ans_mini_batch_loader(qa_anno_dict,
-                                                region_anno_dict,
-                                                ans_vocab, vocab,
-                                                image_dir, mean_image,
-                                                start_index+i*batch_size,
-                                                batch_size,
-                                                img_height, img_width, 3)
+        region_score, partition= batch_creator \
+            .ans_mini_batch_loader(qa_anno_dict, region_anno_dict,
+                                   ans_vocab, vocab, image_dir, mean_image,
+                                   start_index+i*batch_size, batch_size,
+                                   img_height, img_width, 3)
 
         feed_dict = {
             placeholders[0] : region_images,
@@ -224,32 +221,43 @@ def train(train_params):
     placeholders = [image_regions, questions, keep_prob, y, region_score]
 
+    # Start Training
+    max_epoch = train_params['max_epoch']
+    max_iter = 5000
+    val_acc_array_epoch = np.zeros([max_epoch])
+    train_acc_array_epoch = np.zeros([max_epoch])
+
+    # Batch creators
+    train_batch_creator = ans_io_helper.batch_creator(1, max_iter*batch_size)
+    val_batch_creator = ans_io_helper.batch_creator(val_start_id, val_start_id
+                                                    + val_set_size - 1)
+    val_small_batch_creator = ans_io_helper.batch_creator(val_start_id,
+                                                          val_start_id +
+                                                          val_set_size_small-1)
+
+    # Check accuracy of restored model
     if train_params['fine_tune']==True:
         restored_accuracy = evaluate(accuracy, qa_anno_dict,
                                      region_anno_dict, ans_vocab,
                                      vocab, image_regions_dir,
                                      mean_image, val_start_id,
                                      val_set_size, batch_size,
-                                     placeholders, 75, 75)
+                                     placeholders, 75, 75,
+                                     val_batch_creator)
         print('Accuracy of restored model: ' + str(restored_accuracy))
 
-    # Start Training
-    max_epoch = train_params['max_epoch']
-    max_iter = 5000
-    val_acc_array_epoch = np.zeros([max_epoch])
-    train_acc_array_epoch = np.zeros([max_epoch])
-
     for epoch in range(start_epoch, max_epoch):
-        iter_ids = range(max_iter)
-        random.shuffle(iter_ids)
-        for i in iter_ids: #range(max_iter):
+        train_batch_creator.shuffle_ids()
+        for i in range(max_iter):
             train_region_images, train_ans_labels, train_questions, \
-            train_region_score, train_partition= \
-                ans_io_helper.ans_mini_batch_loader(qa_anno_dict, region_anno_dict,
-                                                    ans_vocab, vocab,
-                                                    image_regions_dir, mean_image,
-                                                    1+i*batch_size, batch_size,
-                                                    75, 75, 3)
+            train_region_score, train_partition= train_batch_creator \
+                .ans_mini_batch_loader(qa_anno_dict, region_anno_dict,
+                                       ans_vocab, vocab,
+                                       image_regions_dir, mean_image,
+                                       1+i*batch_size, batch_size,
+                                       75, 75, 3)
 
             feed_dict_train = {
                 image_regions : train_region_images,
@@ -258,7 +266,7 @@ def train(train_params):
                 y: train_ans_labels,
                 region_score: train_region_score,
             }
 
             if pretrained_vars_low_lr:
                 _, _, current_train_batch_acc, y_pred_eval, loss_eval = \
                     sess.run([train_step_low_lr, train_step_high_lr,
@@ -280,10 +288,10 @@ def train(train_params):
                                         region_anno_dict, ans_vocab, vocab,
                                         image_regions_dir, mean_image,
                                         val_start_id, val_set_size_small,
-                                        batch_size, placeholders, 75, 75)
-                print('Iter: ' + str(i+1) + ' Val Sm Acc: ' + str(val_accuracy)
-                      + ' Loss: ' + str(loss_eval))
+                                        batch_size, placeholders, 75, 75,
+                                        val_small_batch_creator)
+                print('Iter: ' + str(i+1) + ' Val Sm Acc: ' + str(val_accuracy))
 
         train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] / max_iter
         val_acc_array_epoch[epoch] = evaluate(accuracy, qa_anno_dict,
@@ -291,7 +299,8 @@ def train(train_params):
                                               vocab, image_regions_dir,
                                               mean_image, val_start_id,
                                               val_set_size, batch_size,
-                                              placeholders, 75, 75)
+                                              placeholders, 75, 75,
+                                              val_batch_creator)
         print('Val Acc: ' + str(val_acc_array_epoch[epoch]) +
               ' Train Acc: ' + str(train_acc_array_epoch[epoch]))
 ...
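The behavioral change in this hunk is worth noting: previously only the visiting order of fixed batches was shuffled (random.shuffle(iter_ids)), so each batch always contained the same questions; shuffling the id list itself via shuffle_ids() re-partitions questions into different batches every epoch. A toy sketch of the difference, with stand-in numbers (not repo code):

# Toy illustration of what shuffle_ids() changes; 8 ids, batches of 4.
import random

ids = list(range(1, 9))   # stands in for batch_creator.id_list
random.shuffle(ids)       # what shuffle_ids() does internally
batch_0 = ids[0:4]        # qa_index(1, 4) with start_id == 1
batch_1 = ids[4:8]        # qa_index(5, 4); membership now varies per epoch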
No preview for this file type
#Embedded file name: /home/tanmay/Code/GenVQA/GenVQA/classifiers/object_classifiers/obj_data_io_helper.py
import json
import sys
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import tensorflow as tf
from scipy import misc
def obj_mini_batch_loader(json_filename, image_dir, mean_image, start_index, batch_size, img_height = 100, img_width = 100, channels = 3):
with open(json_filename, 'r') as json_file:
json_data = json.load(json_file)
obj_images = np.empty(shape=[9 * batch_size, img_height / 3,
img_width / 3,
channels])
obj_labels = np.zeros(shape=[9 * batch_size, 4])
for i in range(start_index, start_index + batch_size):
image_name = os.path.join(image_dir, str(i) + '.jpg')
image = misc.imresize(mpimg.imread(image_name), (img_height, img_width), interp='nearest')
crop_shape = np.array([image.shape[0], image.shape[1]]) / 3
selected_anno = [ q for q in json_data if q['image_id'] == i ]
grid_config = selected_anno[0]['config']
counter = 0
for grid_row in range(0, 3):
for grid_col in range(0, 3):
start_row = grid_row * crop_shape[0]
start_col = grid_col * crop_shape[1]
cropped_image = image[start_row:start_row + crop_shape[0], start_col:start_col + crop_shape[1], :]
if np.ndim(mean_image) == 0:
obj_images[9 * (i - start_index) + counter, :, :, :] = cropped_image / 254.0
else:
obj_images[9 * (i - start_index) + counter, :, :, :] = (cropped_image / 254.0) - mean_image
obj_labels[9 * (i - start_index) + counter, grid_config[6 * grid_row + 2 * grid_col]] = 1
counter = counter + 1
return (obj_images, obj_labels)
def mean_image_batch(json_filename, image_dir, start_index, batch_size, img_height = 100, img_width = 100, channels = 3):
batch = obj_mini_batch_loader(json_filename, image_dir, np.empty([]), start_index, batch_size, img_height, img_width, channels)
mean_image = np.mean(batch[0], 0)
return mean_image
def mean_image(json_filename, image_dir, num_images, batch_size, img_height = 100, img_width = 100, channels = 3):
max_iter = np.floor(num_images / batch_size)
mean_image = np.zeros([img_height / 3, img_width / 3, channels])
for i in range(max_iter.astype(np.int16)):
mean_image = mean_image + mean_image_batch(json_filename, image_dir, 1 + i * batch_size, batch_size, img_height, img_width, channels)
mean_image = mean_image / max_iter
tmp_mean_image = mean_image * 254
return mean_image
class html_obj_table_writer:
def __init__(self, filename):
self.filename = filename
self.html_file = open(self.filename, 'w')
self.html_file.write('<!DOCTYPE html>\n<html>\n<body>\n<table border="1" style="width:100%"> \n')
def add_element(self, col_dict):
self.html_file.write(' <tr>\n')
for key in range(len(col_dict)):
self.html_file.write(' <td>{}</td>\n'.format(col_dict[key]))
self.html_file.write(' </tr>\n')
def image_tag(self, image_path, height, width):
return '<img src="{}" alt="IMAGE NOT FOUND!" height={} width={}>'.format(image_path, height, width)
def close_file(self):
self.html_file.write('</table>\n</body>\n</html>')
self.html_file.close()
if __name__ == '__main__':
html_writer = html_obj_table_writer('/home/tanmay/Code/GenVQA/Exp_Results/Shape_Classifier_v_1/trial.html')
col_dict = {0: 'sam',
1: html_writer.image_tag('something.png', 25, 25)}
html_writer.add_element(col_dict)
html_writer.close_file()
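For reference, a plausible way to precompute and cache the dataset mean with the helpers above; the paths are placeholders, but the test script later in this commit loads exactly such a mean_image.npy:

# Hypothetical precomputation of the mean image (paths are placeholders).
import numpy as np
import obj_data_io_helper as helper

json_path = '/path/to/train_anno.json'
img_dir = '/path/to/images'
mean_img = helper.mean_image(json_path, img_dir, 1000, 100, 75, 75)
np.save('mean_image.npy', mean_img)  # reused when loading batches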
@@ -8,7 +8,6 @@ import matplotlib.image as mpimg
 from scipy import misc
 
 region = namedtuple('region','image score coord')
-
 def parse_region_anno(json_filename):
     with open(json_filename,'r') as json_file:
         raw_data = json.load(json_file)
@@ -19,7 +18,6 @@ def parse_region_anno(json_filename):
     return region_anno_dict
 
-
 def get_region_coords(img_height, img_width):
     region_coords_ = np.array([[   1,   1, 100, 100],
                                [ 101,   1, 200, 100],
...
No preview for this file type
@@ -282,8 +282,7 @@ def loss(y, y_pred):
     cross_entropy = -tf.reduce_sum(y * tf.log(y_pred_clipped),
                                    name='cross_entropy')
     batch_size = tf.shape(y)
-    print 'Batch Size:' + str(tf.cast(batch_size[0],tf.float32))
-    return tf.truediv(cross_entropy, tf.cast(20,tf.float32))#batch_size[0],tf.float32))
+    return tf.truediv(cross_entropy, tf.cast(batch_size[0],tf.float32))
 
 if __name__ == '__main__':
...
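This hunk removes a leftover debug print and a hard-coded divisor of 20, so the loss is again mean cross-entropy over the actual batch. A numpy sketch of the corrected computation (toy values; the clip bounds here stand in for however y_pred_clipped is defined upstream):

# Numpy equivalent of the corrected loss on toy data.
import numpy as np

y = np.array([[0., 1.], [1., 0.]])            # one-hot labels, batch of 2
y_pred = np.array([[0.2, 0.8], [0.6, 0.4]])   # softmax outputs
y_pred_clipped = np.clip(y_pred, 1e-10, 1.0)  # assumed clipping
loss = -np.sum(y * np.log(y_pred_clipped)) / y.shape[0]  # divide by batch size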
No preview for this file type
@@ -73,8 +73,8 @@ ans_classifier_train_params = {
     'crop_n_save_regions': False,
     'max_epoch': 10,
     'batch_size': 20,
-    'fine_tune': False,
-    'start_model': 9,
+    'fine_tune': True,
+    'start_model': 3,
 }
 
 if __name__=='__main__':
...
import sys
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import tensorflow as tf
import obj_data_io_helper as shape_data_loader
from train_obj_classifier import placeholder_inputs, comp_graph_v_1, evaluation
sess=tf.InteractiveSession()
x, y, keep_prob = placeholder_inputs()
y_pred = comp_graph_v_1(x, y, keep_prob)
accuracy = evaluation(y, y_pred)
saver = tf.train.Saver()
saver.restore(sess, '/home/tanmay/Code/GenVQA/Exp_Results/Shape_Classifier_v_1/obj_classifier_9.ckpt')
mean_image = np.load('/home/tanmay/Code/GenVQA/Exp_Results/Shape_Classifier_v_1/mean_image.npy')
# Test Data
test_json_filename = '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json'
image_dir = '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images'
# HTML file writer
html_writer = shape_data_loader.html_obj_table_writer('/home/tanmay/Code/GenVQA/Exp_Results/Shape_Classifier_v_1/trial.html')
batch_size = 100
correct = 0
for i in range(1): #50
test_batch = shape_data_loader.obj_mini_batch_loader(test_json_filename, image_dir, mean_image, 10000+i*batch_size, batch_size, 75, 75)
feed_dict_test={x: test_batch[0], y: test_batch[1], keep_prob: 1.0}
result = sess.run([accuracy, y_pred], feed_dict=feed_dict_test)
correct = correct + result[0]*batch_size
print(correct)
for row in range(batch_size):
col_dict = {
0: test_batch[1][row,:],
1: result[1][row, :]}  # evaluated predictions; y_pred itself is a TF tensor
html_writer.add_element(col_dict)
html_writer.close_file()
print('Test Accuracy: {}'.format(correct/5000))