Commit 92f02b56 authored by tgupta6

optimized atr training and evaluation

parent 99ac4c72
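The central change in this commit is that the annotation JSON is now parsed once, up front, and indexed by image_id, instead of being re-opened and linearly scanned on every call to the mini-batch loaders; variable restoring/initialization and checkpointing in the attribute trainer were also reworked (see the hunks below). A minimal sketch of the indexing pattern, assuming a hypothetical helper name and file path (the real call sites are in the training and evaluation scripts further down):

import json

def index_annotations_by_image_id(json_filename):
    # Parse the annotation file once and key each grid config by image_id,
    # so every mini batch does an O(1) dict lookup instead of re-reading the
    # file and scanning the full annotation list.
    with open(json_filename, 'r') as json_file:
        raw_json_data = json.load(json_file)
    indexed = dict()
    for entry in raw_json_data:
        if entry['image_id'] not in indexed:
            indexed[entry['image_id']] = entry['config']
    return indexed

# Hypothetical usage (mirrors the new loader signature introduced below):
# train_json_data = index_annotations_by_image_id('train_anno.json')
# batch = atr_data_loader.atr_mini_batch_loader(train_json_data, image_dir, mean_image, 1, 10)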
@@ -6,47 +6,49 @@ import matplotlib.image as mpimg
 import numpy as np
 import tensorflow as tf
 from scipy import misc
+import time
 
-def atr_mini_batch_loader(json_filename, image_dir, mean_image, start_index, batch_size, img_height=100, img_width=100, channels=3):
-    with open(json_filename, 'r') as json_file:
-        json_data = json.load(json_file)
+def atr_mini_batch_loader(json_data, image_dir, mean_image, start_index, batch_size, img_height=100, img_width=100, channels=3):
     atr_images = np.empty(shape=[9*batch_size, img_height/3, img_width/3, channels])
     atr_labels = np.zeros(shape=[9*batch_size, 4])
-    for i in range(start_index, start_index + batch_size):
+    for i in xrange(start_index, start_index + batch_size):
         image_name = os.path.join(image_dir, str(i) + '.jpg')
         image = misc.imresize(mpimg.imread(image_name),(img_height, img_width), interp='nearest')
+        #image = np.zeros(shape=[img_height, img_width, channels])
         crop_shape = np.array([image.shape[0], image.shape[1]])/3
-        selected_anno = [q for q in json_data if q['image_id']==i]
-        grid_config = selected_anno[0]['config']
+        grid_config = json_data[i]
         counter = 0;
-        for grid_row in range(0,3):
-            for grid_col in range(0,3):
+        for grid_row in xrange(0,3):
+            for grid_col in xrange(0,3):
                 start_row = grid_row*crop_shape[0]
                 start_col = grid_col*crop_shape[1]
                 cropped_image = image[start_row:start_row+crop_shape[0], start_col:start_col+crop_shape[1], :]
                 if np.ndim(mean_image)==0:
                     atr_images[9*(i-start_index)+counter,:,:,:] = cropped_image/254.0
                 else:
                     atr_images[9*(i-start_index)+counter,:,:,:] = (cropped_image / 254.0) - mean_image
-                atr_labels[9*(i-start_index)+counter, grid_config[6*grid_row+2*grid_col+1]] = 1
+                if grid_config[6*grid_row+2*grid_col]==0: # if the space is blank
+                    atr_labels[9*(i-start_index)+counter,3] = 1
+                else:
+                    atr_labels[9*(i-start_index)+counter, grid_config[6*grid_row+2*grid_col+1]] = 1
                 counter = counter + 1
     return (atr_images, atr_labels)
 
-def mean_image_batch(json_filename, image_dir, start_index, batch_size, img_height=100, img_width=100, channels=3):
-    batch = atr_mini_batch_loader(json_filename, image_dir, np.empty([]), start_index, batch_size, img_height, img_width, channels)
+def mean_image_batch(json_data, image_dir, start_index, batch_size, img_height=100, img_width=100, channels=3):
+    batch = atr_mini_batch_loader(json_data, image_dir, np.empty([]), start_index, batch_size, img_height, img_width, channels)
     mean_image = np.mean(batch[0], 0)
     return mean_image
 
-def mean_image(json_filename, image_dir, num_images, batch_size, img_height=100, img_width=100, channels=3):
+def mean_image(json_data, image_dir, num_images, batch_size, img_height=100, img_width=100, channels=3):
     max_iter = np.floor(num_images/batch_size)
     mean_image = np.zeros([img_height/3, img_width/3, channels])
     for i in range(max_iter.astype(np.int16)):
-        mean_image = mean_image + mean_image_batch(json_filename, image_dir, 1+i*batch_size, batch_size, img_height, img_width, channels)
+        mean_image = mean_image + mean_image_batch(json_data, image_dir, 1+i*batch_size, batch_size, img_height, img_width, channels)
     mean_image = mean_image/max_iter
     return mean_image
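The loader splits every image into a 3x3 grid of crops and now emits a 4-way one-hot label per crop: columns 0-2 are the colour classes (red, green, blue, matching the evaluation script's color_dict further down) and column 3 marks a blank cell. A small standalone sketch of that label layout, assuming the same flat 18-element grid_config encoding as above (an occupancy/colour pair per cell, row-major, so 2*cell == 6*grid_row + 2*grid_col):

import numpy as np

def grid_labels(grid_config, num_classes=4):
    # grid_config: flat list of 18 ints, (occupancy, colour) per cell, row-major.
    # Returns a [9, 4] one-hot array; column 3 marks blank cells.
    labels = np.zeros((9, num_classes))
    for cell in range(9):
        occupancy = grid_config[2 * cell]
        colour = grid_config[2 * cell + 1]
        if occupancy == 0:           # blank cell
            labels[cell, 3] = 1
        else:                        # coloured shape: 0=red, 1=green, 2=blue
            labels[cell, colour] = 1
    return labels

# Example: first cell holds a green shape, the rest are red, the last cell is blank.
example_config = [1, 1] + [1, 0] * 7 + [0, 0]
print(grid_labels(example_config))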
No preview for this file type
 import sys
+import json
 import os
 import matplotlib.pyplot as plt
 import matplotlib.image as mpimg
@@ -15,6 +16,7 @@ def eval(eval_params):
     x, y, keep_prob = graph_creator.placeholder_inputs()
     _ = graph_creator.obj_comp_graph(x, 1.0)
     obj_feat = tf.get_collection('obj_feat', scope='obj/conv2')
     y_pred = graph_creator.atr_comp_graph(x, keep_prob, obj_feat[0])
     accuracy = graph_creator.evaluation(y, y_pred)
@@ -23,6 +25,13 @@ def eval(eval_params):
     mean_image = np.load(os.path.join(eval_params['out_dir'], 'mean_image.npy'))
     test_json_filename = eval_params['test_json']
+    with open(test_json_filename, 'r') as json_file:
+        raw_json_data = json.load(json_file)
+    test_json_data = dict()
+    for entry in raw_json_data:
+        if entry['image_id'] not in test_json_data:
+            test_json_data[entry['image_id']]=entry['config']
     image_dir = eval_params['image_dir']
     html_dir = eval_params['html_dir']
     if not os.path.exists(html_dir):
@@ -36,12 +45,13 @@ def eval(eval_params):
     color_dict = {
         0: 'red', # blanks are treated as red
         1: 'green',
-        2: 'blue'}
+        2: 'blue',
+        3: 'blank'}
     batch_size = 100
     correct = 0
     for i in range(50):
-        test_batch = atr_data_loader.atr_mini_batch_loader(test_json_filename, image_dir, mean_image, 10000+i*batch_size, batch_size, 75, 75)
+        test_batch = atr_data_loader.atr_mini_batch_loader(test_json_data, image_dir, mean_image, 10000+i*batch_size, batch_size, 75, 75)
         feed_dict_test={x: test_batch[0], y: test_batch[1], keep_prob: 1.0}
         result = sess.run([accuracy, y_pred], feed_dict=feed_dict_test)
         correct = correct + result[0]*batch_size
No preview for this file type
 import sys
+import json
 import os
+import time
 import matplotlib.pyplot as plt
 import matplotlib.image as mpimg
 import numpy as np
@@ -14,15 +16,30 @@ def train(train_params):
     x, y, keep_prob = graph_creator.placeholder_inputs()
     _ = graph_creator.obj_comp_graph(x, 1.0)
     obj_feat = tf.get_collection('obj_feat', scope='obj/conv2')
+    # Session Saver
+    obj_saver = tf.train.Saver()
+    # Restore obj network parameters
+    obj_saver.restore(sess, train_params['obj_model_name'] + '-' + str(train_params['obj_global_step']))
     y_pred = graph_creator.atr_comp_graph(x, keep_prob, obj_feat[0])
     cross_entropy = graph_creator.loss(y, y_pred)
     vars_to_opt = tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr')
+    vars_to_restore = tf.get_collection(tf.GraphKeys.VARIABLES, scope='obj')
 
-    train_step = tf.train.AdamOptimizer(train_params['adam_lr']).minimize(cross_entropy, var_list=vars_to_opt)
+    all_vars = tf.get_collection(tf.GraphKeys.VARIABLES)
+    vars_to_init = [var for var in all_vars if var not in vars_to_restore]
+    print('Variables that are being optimized: ' + ' '.join([var.name for var in vars_to_opt]))
+    print('All variables: ' + ' '.join([var.name for var in all_vars]))
+    train_step = tf.train.AdamOptimizer(train_params['adam_lr']).minimize(cross_entropy, var_list=vars_to_opt)
+    print('Variables that will be initialized: ' + ' '.join([var.name for var in vars_to_init]))
     accuracy = graph_creator.evaluation(y, y_pred)
+    # Session saver for atr variables
+    atr_saver = tf.train.Saver(vars_to_opt)
+    obj_atr_saver = tf.train.Saver(all_vars)
     outdir = train_params['out_dir']
     if not os.path.exists(outdir):
         os.mkdir(outdir)
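The training script now restores the pretrained 'obj' weights before the attribute branch is added, and then initializes only the variables that were not restored, so the checkpointed weights are never clobbered by a global initializer. A minimal sketch of that pattern with toy variables, using the same TF 0.x-era API as the code above (the checkpoint path is hypothetical):

import tensorflow as tf

sess = tf.Session()

# Toy stand-ins for the real obj/atr networks.
with tf.variable_scope('obj'):
    w_obj = tf.get_variable('w', shape=[4, 4])
with tf.variable_scope('atr'):
    w_atr = tf.get_variable('w', shape=[4, 2])

vars_to_restore = tf.get_collection(tf.GraphKeys.VARIABLES, scope='obj')
all_vars = tf.get_collection(tf.GraphKeys.VARIABLES)
vars_to_init = [var for var in all_vars if var not in vars_to_restore]

obj_saver = tf.train.Saver(vars_to_restore)
# obj_saver.restore(sess, '/path/to/obj_classifier-1')  # hypothetical checkpoint
sess.run(tf.initialize_variables(vars_to_init))         # leaves restored weights untouched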
@@ -31,23 +48,32 @@ def train(train_params):
     img_width = 75
     img_height = 75
     train_json_filename = train_params['train_json']
+    with open(train_json_filename, 'r') as json_file:
+        raw_json_data = json.load(json_file)
+    train_json_data = dict()
+    for entry in raw_json_data:
+        if entry['image_id'] not in train_json_data:
+            train_json_data[entry['image_id']]=entry['config']
     image_dir = train_params['image_dir']
-    mean_image = atr_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width)
+    if train_params['mean_image']=='':
+        print('Computing mean image')
+        mean_image = atr_data_loader.mean_image(train_json_data, image_dir, 1000, 100, img_height, img_width)
+    else:
+        print('Loading mean image')
+        mean_image = np.load(train_params['mean_image'])
     np.save(os.path.join(outdir, 'mean_image.npy'), mean_image)
 
     # Val Data
-    val_batch = atr_data_loader.atr_mini_batch_loader(train_json_filename, image_dir, mean_image, 9501, 499, img_height, img_width)
+    print('Loading validation data')
+    val_batch = atr_data_loader.atr_mini_batch_loader(train_json_data, image_dir, mean_image, 9501, 499, img_height, img_width)
     feed_dict_val={x: val_batch[0], y: val_batch[1], keep_prob: 1.0}
 
-    # Session Saver
-    saver = tf.train.Saver()
     # Start Training
-    sess.run(tf.initialize_all_variables())
-    # Restore obj network parameters
-    saver.restore(sess, train_params['obj_model_name'] + '-' + str(train_params['obj_global_step']))
+    print('Initializing variables')
+    sess.run(tf.initialize_variables(vars_to_init))
     batch_size = 10
     max_epoch = 2
@@ -57,18 +83,23 @@ def train(train_params):
     train_acc_array_epoch = np.zeros([max_epoch])
     for epoch in range(max_epoch):
         for i in range(max_iter):
-            train_batch = atr_data_loader.atr_mini_batch_loader(train_json_filename, image_dir, mean_image, 1+i*batch_size, batch_size, img_height, img_width)
+            print(i)
+            # start = time.time()
+            train_batch = atr_data_loader.atr_mini_batch_loader(train_json_data, image_dir, mean_image, 1+i*batch_size, batch_size, img_height, img_width)
+            # end = time.time()
+            # print(end-start)
             feed_dict_train={x: train_batch[0], y: train_batch[1], keep_prob: 0.5}
             _, current_train_batch_acc = sess.run([train_step, accuracy], feed_dict=feed_dict_train)
            train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] + current_train_batch_acc
-            val_acc_array_iter[i+epoch*max_iter] = accuracy.eval(feed_dict_val)
-            plotter.plot_accuracy(np.arange(0,i+1+epoch*max_iter)+1, val_acc_array_iter[0:i+1+epoch*max_iter], xlim=[1, max_epoch*max_iter], ylim=[0, 1.], savePath=os.path.join(outdir,'valAcc_vs_iter.pdf'))
-            print('Step: {} Val Accuracy: {}'.format(i+1+epoch*max_iter, val_acc_array_iter[i+epoch*max_iter]))
+            # val_acc_array_iter[i+epoch*max_iter] = accuracy.eval(feed_dict_val)
+            # plotter.plot_accuracy(np.arange(0,i+1+epoch*max_iter)+1, val_acc_array_iter[0:i+1+epoch*max_iter], xlim=[1, max_epoch*max_iter], ylim=[0, 1.], savePath=os.path.join(outdir,'valAcc_vs_iter.pdf'))
+            # print('Step: {} Val Accuracy: {}'.format(i+1+epoch*max_iter, val_acc_array_iter[i+epoch*max_iter]))
         train_acc_array_epoch[epoch] = train_acc_array_epoch[epoch] / max_iter
         val_acc_array_epoch[epoch] = val_acc_array_iter[i+epoch*max_iter]
         plotter.plot_accuracies(xdata=np.arange(0,epoch+1)+1, ydata_train=train_acc_array_epoch[0:epoch+1], ydata_val=val_acc_array_epoch[0:epoch+1], xlim=[1, max_epoch], ylim=[0, 1.], savePath=os.path.join(outdir,'acc_vs_epoch.pdf'))
-        save_path = saver.save(sess, os.path.join(outdir,'atr_classifier'), global_step=epoch)
+        _ = atr_saver.save(sess, os.path.join(outdir,'atr_classifier'), global_step=epoch)
+        _ = obj_atr_saver.save(sess, os.path.join(outdir,'obj_atr_classifier'), global_step=epoch)
 
     sess.close()
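Checkpointing also changes: instead of one generic saver, the epoch loop now writes an attribute-only checkpoint (atr_classifier) and a combined checkpoint (obj_atr_classifier), and the evaluation config below is pointed at the combined one. A minimal sketch of keeping two savers over different variable subsets, with toy variables and a hypothetical output directory:

import tensorflow as tf

with tf.variable_scope('obj'):
    w_obj = tf.get_variable('w', shape=[4, 4])
with tf.variable_scope('atr'):
    w_atr = tf.get_variable('w', shape=[4, 2])

atr_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='atr')
all_vars = tf.get_collection(tf.GraphKeys.VARIABLES)

atr_saver = tf.train.Saver(atr_vars)      # small checkpoint: atr weights only
obj_atr_saver = tf.train.Saver(all_vars)  # full checkpoint: obj + atr weights

sess = tf.Session()
sess.run(tf.initialize_all_variables())
atr_saver.save(sess, '/tmp/atr_classifier', global_step=0)          # -> atr_classifier-0
obj_atr_saver.save(sess, '/tmp/obj_atr_classifier', global_step=0)  # -> obj_atr_classifier-0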
No preview for this file type
@@ -26,7 +26,12 @@ def train(train_params):
     img_height = 75
     train_json_filename = train_params['train_json']
     image_dir = train_params['image_dir']
-    mean_image = shape_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width)
+    if train_params['mean_image']=='':
+        print('Computing mean image')
+        mean_image = atr_data_loader.mean_image(train_json_filename, image_dir, 1000, 100, img_height, img_width)
+    else:
+        print('Loading mean image')
+        mean_image = np.load(train_params['mean_image'])
     np.save(os.path.join(outdir, 'mean_image.npy'), mean_image)
 
     # Val Data
No preview for this file type
No preview for this file type
@@ -12,7 +12,7 @@ import attribute_classifiers.eval_atr_classifier as atr_evaluator
 workflow = {
     'train_obj': False,
     'eval_obj': False,
-    'train_atr': True,
+    'train_atr': False,
     'eval_atr': True,
 }
@@ -21,6 +21,8 @@ obj_classifier_train_params = {
     'adam_lr': 0.001,
     'train_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/train_anno.json',
     'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
+    'mean_image': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier/mean_image.npy',
+    # 'mean_image': '',
 }
 
 obj_classifier_eval_params = {
@@ -39,11 +41,13 @@ atr_classifier_train_params = {
     'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',
     'obj_model_name': obj_classifier_eval_params['model_name'],
     'obj_global_step': 1,
+    'mean_image': '/home/tanmay/Code/GenVQA/Exp_Results/Obj_Classifier/mean_image.npy',
+    # 'mean_image': '',
 }
 
 atr_classifier_eval_params = {
     'out_dir': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier',
-    'model_name': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/atr_classifier',
+    'model_name': '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier/obj_atr_classifier',
     'global_step': 1,
     'test_json': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json',
     'image_dir': '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images',