Skip to content
Snippets Groups Projects
Commit 573d0b72 authored by tgupta6's avatar tgupta6
Browse files

Working code for atr (color) classifier

parent 2e53b8fb
Branches object_classifiers
No related tags found
No related merge requests found
......@@ -42,7 +42,7 @@ def atr_mini_batch_loader(json_filename, image_dir, mean_image, start_index, bat
return (atr_images, atr_labels)
def mean_image_batch(json_filename, image_dir, start_index, batch_size, img_height=100, img_width=100, channels=3):
batch = obj_mini_batch_loader(json_filename, image_dir, np.empty([]), start_index, batch_size, img_height, img_width, channels)
batch = atr_mini_batch_loader(json_filename, image_dir, np.empty([]), start_index, batch_size, img_height, img_width, channels)
mean_image = np.mean(batch[0], 0)
return mean_image
......
File added
......@@ -5,33 +5,33 @@ import matplotlib.image as mpimg
import numpy as np
from scipy import misc
import tensorflow as tf
import obj_data_io_helper as shape_data_loader
from train_obj_classifier import placeholder_inputs, comp_graph_v_1, evaluation
import atr_data_io_helper as atr_data_loader
from train_atr_classifier import placeholder_inputs, comp_graph_v_2, evaluation
sess=tf.InteractiveSession()
x, y, keep_prob = placeholder_inputs()
y_pred = comp_graph_v_1(x, y, keep_prob)
y_pred = comp_graph_v_2(x, y, keep_prob)
accuracy = evaluation(y, y_pred)
saver = tf.train.Saver()
saver.restore(sess, '/home/tanmay/Code/GenVQA/Exp_Results/Shape_Classifier_v_1/obj_classifier_9.ckpt')
saver.restore(sess, '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier_v_1/obj_classifier_1.ckpt')
mean_image = np.load('/home/tanmay/Code/GenVQA/Exp_Results/Shape_Classifier_v_1/mean_image.npy')
mean_image = np.load('/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier_v_1/mean_image.npy')
# Test Data
test_json_filename = '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/test_anno.json'
image_dir = '/home/tanmay/Code/GenVQA/GenVQA/shapes_dataset/images'
# Base dir for html visualizer
html_dir = '/home/tanmay/Code/GenVQA/Exp_Results/Shape_Classifier_v_1/html'
html_dir = '/home/tanmay/Code/GenVQA/Exp_Results/Atr_Classifier_v_1/html'
if not os.path.exists(html_dir):
os.mkdir(html_dir)
# HTML file writer
html_writer = shape_data_loader.html_obj_table_writer(os.path.join(html_dir,'index.html'))
html_writer = atr_data_loader.html_atr_table_writer(os.path.join(html_dir,'index.html'))
col_dict={
0: 'Grount Truth',
1: 'Prediction',
......@@ -47,7 +47,7 @@ shape_dict = {
batch_size = 100
correct = 0
for i in range(50):
test_batch = shape_data_loader.obj_mini_batch_loader(test_json_filename, image_dir, mean_image, 10000+i*batch_size, batch_size, 75, 75)
test_batch = atr_data_loader.atr_mini_batch_loader(test_json_filename, image_dir, mean_image, 10000+i*batch_size, batch_size, 75, 75)
feed_dict_test={x: test_batch[0], y: test_batch[1], keep_prob: 1.0}
result = sess.run([accuracy, y_pred], feed_dict=feed_dict_test)
correct = correct + result[0]*batch_size
......
......@@ -89,24 +89,23 @@ def comp_graph_v_1(x, y, keep_prob):
def comp_graph_v_2(x, y, keep_prob):
# Specify computation graph
W_conv1 = weight_variable([5, 5, 3, 10])
b_conv1 = bias_variable([10])
W_conv1 = weight_variable([5, 5, 3, 4])
b_conv1 = bias_variable([4])
h_conv1 = tf.nn.relu(conv2d(x, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
h_conv1_drop = tf.nn.dropout(h_pool1, keep_prob)
W_conv2 = weight_variable([5, 5, 10, 20])
b_conv2 = bias_variable([20])
W_conv2 = weight_variable([3, 3, 4, 8])
b_conv2 = bias_variable([8])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
h_conv2_drop = tf.nn.dropout(h_pool2, keep_prob)
W_fc1 = weight_variable([7*7*20, 4])
W_fc1 = weight_variable([7*7*8, 4])
b_fc1 = bias_variable([4])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*20])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*8])
h_pool2_flat_drop = tf.nn.dropout(h_pool2_flat, keep_prob)
y_pred = tf.nn.softmax(tf.matmul(h_pool2_flat_drop,W_fc1) + b_fc1)
......
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment