Skip to content
Snippets Groups Projects
Commit 39219523 authored by tgupta6's avatar tgupta6
Browse files

classifiers/object_classifiers/obj_data_io_helper.py

parent b9351f74
No related branches found
No related tags found
No related merge requests found
#Embedded file name: /home/tanmay/Code/GenVQA/GenVQA/classifiers/object_classifiers/obj_data_io_helper.py
import json
import sys
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import tensorflow as tf
from scipy import misc
def obj_mini_batch_loader(json_filename, image_dir, mean_image, start_index, batch_size, img_height = 100, img_width = 100, channels = 3):
with open(json_filename, 'r') as json_file:
json_data = json.load(json_file)
obj_images = np.empty(shape=[9 * batch_size,
img_height / 3,
img_width / 3,
channels])
obj_labels = np.zeros(shape=[9 * batch_size, 4])
for i in range(start_index, start_index + batch_size):
image_name = os.path.join(image_dir, str(i) + '.jpg')
image = misc.imresize(mpimg.imread(image_name), (img_height, img_width), interp='nearest')
crop_shape = np.array([image.shape[0], image.shape[1]]) / 3
selected_anno = [ q for q in json_data if q['image_id'] == i ]
grid_config = selected_anno[0]['config']
counter = 0
for grid_row in range(0, 3):
for grid_col in range(0, 3):
start_row = grid_row * crop_shape[0]
start_col = grid_col * crop_shape[1]
cropped_image = image[start_row:start_row + crop_shape[0], start_col:start_col + crop_shape[1], :]
if np.ndim(mean_image) == 0:
obj_images[9 * (i - start_index) + counter, :, :, :] = cropped_image / 254.0
else:
obj_images[9 * (i - start_index) + counter, :, :, :] = (cropped_image - mean_image) / 254
obj_labels[9 * (i - start_index) + counter, grid_config[6 * grid_row + 2 * grid_col]] = 1
counter = counter + 1
return (obj_images, obj_labels)
def mean_image_batch(json_filename, image_dir, start_index, batch_size, img_height = 100, img_width = 100, channels = 3):
batch = obj_mini_batch_loader(json_filename, image_dir, np.empty([]), start_index, batch_size, img_height, img_width, channels)
mean_image = np.mean(batch[0], 0)
return mean_image
def mean_image(json_filename, image_dir, num_images, batch_size, img_height = 100, img_width = 100, channels = 3):
max_iter = np.floor(num_images / batch_size)
mean_image = np.zeros([img_height / 3, img_width / 3, channels])
for i in range(max_iter.astype(np.int16)):
mean_image = mean_image + mean_image_batch(json_filename, image_dir, 1 + i * batch_size, batch_size, img_height, img_width, channels)
mean_image = mean_image / max_iter
tmp_mean_image = mean_image * 254
return mean_image
class html_obj_table_writer:
def __init__(self, filename):
self.filename = filename
self.html_file = open(self.filename, 'w')
self.html_file.write('<!DOCTYPE html>\n<html>\n<body>\n<table border="1" style="width:100%"> \n')
def add_element(self, col_dict):
self.html_file.write(' <tr>\n')
for key in range(len(col_dict)):
self.html_file.write(' <td>{}</td>\n'.format(col_dict[key]))
self.html_file.write(' </tr>\n')
def image_tag(self, image_path, height, width):
return '<img src="{}" alt="IMAGE NOT FOUND!" height={} width={}>'.format(image_path, height, width)
def close_file(self):
self.html_file.write('</table>\n</body>\n</html>')
self.html_file.close()
if __name__ == '__main__':
html_writer = html_obj_table_writer('/home/tanmay/Code/GenVQA/Exp_Results/Shape_Classifier_v_1/trial.html')
col_dict = {0: 'sam',
1: html_writer.image_tag('something.png', 25, 25)}
html_writer.add_element(col_dict)
html_writer.close_file()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment