Commit a3527aad authored by tgupta6

multithreaded implementation of batch creation working

parent fe505614
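The core of this change replaces the multiprocessing Pool.map call in data.get_parallel with one threading.Thread per sample. A minimal sketch of that pattern, with a hypothetical load_sample worker standing in for get_single (not part of this repository):

import threading

def load_sample(sample, results, index):
    # stand-in for data.get_single: do the per-sample work and write the
    # result into the slot reserved for this thread
    results[index] = {'sample': sample}

def load_parallel(samples):
    results = [None] * len(samples)
    workers = []
    for index, sample in enumerate(samples):
        worker = threading.Thread(target=load_sample,
                                  args=(sample, results, index))
        worker.daemon = True  # do not keep the interpreter alive on exit
        worker.start()
        workers.append(worker)
    for worker in workers:
        worker.join()  # wait for every sample before assembling the batch
    return results

if __name__ == '__main__':
    print(load_parallel(range(5)))

Because each thread writes to a distinct index, no lock is needed around the shared results list; the joins guarantee every slot is filled before the batch is assembled.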
@@ -8,7 +8,7 @@ unknown_token = 'UNK'
# Data paths
data_absolute_path = '/home/tanmay/Data/VisualGenome'
image_dir = os.path.join(data_absolute_path, 'images')
image_dir = os.path.join(data_absolute_path, 'cropped_regions')
object_labels_json = os.path.join(
data_absolute_path,
'restructured/object_labels.json')
......
@@ -9,11 +9,17 @@ from multiprocessing import Pool
import tftools.data
import image_io
import constants
import Queue
#import thread
import threading
import tensorflow as tf
_unknown_token = constants.unknown_token
def unwrap_self_get_single(arg, **kwarg):
return data.get_single(*arg, **kwarg)
class data():
def __init__(self,
image_dir,
@@ -39,6 +45,7 @@ class data():
self.regions = self.read_json_file(regions_json)
self.num_regions = len(self.regions)
self.create_sample_to_region_dict()
self.num_threads = 0
def create_sample_to_region_dict(self):
self.sample_to_region_dict = \
@@ -66,33 +73,81 @@ class data():
for index, sample in enumerate(samples):
batch['region_ids'][index] = self.sample_to_region_dict[sample]
batch['images'][index, :, :, :] = self.get_region_image(sample)
batch['object_labels'][index, :] = self.get_object_label(sample)
batch['attribute_labels'][index,:] = \
self.get_attribute_label(sample)
batch['images'][index, :, :, :], read_success = \
self.get_region_image(sample)
if read_success:
batch['object_labels'][index, :] = self.get_object_label(sample)
batch['attribute_labels'][index,:] = \
self.get_attribute_label(sample)
return batch
def get_single(self, sample):
print sample
batch = dict()
batch['region_ids'] = dict()
batch['images'] = np.zeros(
def get_single(self, sample, result, index):
try:
# sample, lock = sample_lock
# lock.acquire()
# print sample
# self.num_threads += 1
# lock.release()
batch = dict()
batch['region_ids'] = dict()
batch['images'] = np.zeros(
[self.h, self.w, self.c], np.float32)
batch['object_labels'] = np.zeros(
[len(self.object_labels_dict)], np.float32)
batch['attribute_labels'] = np.zeros(
[len(self.attribute_labels_dict)], np.float32)
batch['object_labels'] = np.zeros(
[len(self.object_labels_dict)], np.float32)
batch['attribute_labels'] = np.zeros(
[len(self.attribute_labels_dict)], np.float32)
batch['region_ids'] = self.sample_to_region_dict[sample]
batch['images'], read_success = self.get_region_image(sample)
if read_success:
batch['object_labels'] = self.get_object_label(sample)
batch['attribute_labels'] = self.get_attribute_label(sample)
batch['region_ids'] = self.sample_to_region_dict[sample]
batch['images'] = self.get_region_image(sample)
batch['object_labels'] = self.get_object_label(sample)
batch['attribute_labels'] = self.get_attribute_label(sample)
result[index] = batch
except Exception, e:
print str(e)
print threading.current_thread().name + ' errored'
# return batch
return batch
def get_parallel(self, samples):
batch_list = self.pool.map(self.get_single, samples)
# batch_list = self.pool.map(self.get_single, samples)
self.num_threads = 0
q = Queue.Queue()
batch_list = [None]*len(samples)
count = 0
indices = range(len(samples))
jobs = []
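# spawn one daemon thread per sample; each thread writes its per-sample dict into batch_list at its own index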
for sample in samples:
worker = threading.Thread(target=self.get_single, args = (sample, batch_list, indices[count]))
worker.setDaemon(True)
worker.start()
jobs.append(worker)
count += 1
for worker in jobs:
worker.join()
# while threading.active_count() > 0:
# print '#Threads: {}'.format(threading.active_count())
# pass
# lock = thread.allocate_lock()
# for sample in samples:
# thread.start_new_thread(self.get_single, (sample,lock,q))
# while (self.num_threads > 4):
# pass
# while (self.num_threads > 0):
# pass
# thread.exit()
# batch_list = []
# while not q.empty():
# batch_list.append(q.get())
# pool = Pool(3)
# batch_list = pool.map(unwrap_self_get_single, zip([self]*len(samples),samples))
batch_size = len(samples)
batch = dict()
batch['region_ids'] = dict()
@@ -109,24 +164,25 @@ class data():
batch['object_labels'][index, :] = single_batch['object_labels']
batch['attribute_labels'][index,:] = single_batch['attribute_labels']
return batch
def get_region_image(self, sample):
region_id = self.sample_to_region_dict[sample]
region = self.regions[region_id]
filename = os.path.join(self.image_dir,
str(region['image_id']) + '.jpg')
image = image_io.imread(filename)
image = self.single_to_three_channel(image)
x, y, w, h = self.get_clipped_region_coords(region, image.shape[0:2])
region_image = image[y:y + h, x:x + w, :]
region_image = image_io.imresize(
region_image,
output_size=(self.h, self.w)).astype(np.float32)
image_subdir = os.path.join(self.image_dir,
str(region['image_id']))
filename = os.path.join(image_subdir,
str(region_id) + '.jpg')
read_success = True
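# if the cropped region cannot be read, fall back to a blank image and flag the failure rather than raising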
try:
region_image = image_io.imread(filename)
except:
read_success = False
region_image = np.zeros([self.h, self.w, self.c], np.float32)
region_image = region_image.astype(np.float32)
return region_image / 255 - self.mean_image
return region_image / 255 - self.mean_image, read_success
def single_to_three_channel(self, image):
if len(image.shape)==3:
@@ -223,11 +279,13 @@ if __name__=='__main__':
print "Region: {}".format(region)
print "Attributes: {}".format(", ".join(attributes))
print "Objects: {}".format(", ".join(objects))
# pdb.set_trace()
data_mgr.regions['5581152']
# image_io.imshow(region_image)
batch_size = 200
num_samples = 1000
num_samples = 10000
num_epochs = 1
offset = 0
@@ -249,7 +307,7 @@ if __name__=='__main__':
batch_generator = tftools.data.async_batch_generator(
data_mgr,
index_generator,
1000)
100)
count = 0
start = time.time()
@@ -261,4 +319,4 @@ if __name__=='__main__':
print 'Time per batch: {}'.format((stop-start)/50.0)
print "Count: {}".format(count)
pool.close()
"""Utility functions for making batch generation easier."""
import numpy as np
from multiprocessing import Process, Queue
from multiprocessing import Process, Queue, Manager, Pool
import time
def sequential(batch_size, num_samples, num_epochs=1, offset=0):
@@ -26,7 +26,7 @@ def batch_generator(data, index_generator, batch_function=None):
"""Generate batches of data.
"""
for samples in index_generator:
batch = data.get(samples)
batch = data.get_parallel(samples)
if batch_function:
output = batch_function(batch)
else:
@@ -43,7 +43,7 @@ def async_batch_generator(data, index_generator, queue_maxsize, batch_function=None):
fetcher = BatchFetcher(queue, batcher)
fetcher.start()
time.sleep(10)
time.sleep(1)
queue_batcher = queue_generator(queue)
return queue_batcher
@@ -74,6 +74,20 @@ class BatchFetcher(Process):
# Signal that fetcher is done.
self.queue.put(None)
class BatchFetcherParallel(Process):
def __init__(self, queue, batch_generator):
super(BatchFetcherParallel, self).__init__()
self.queue = queue
self.batch_generator = batch_generator
def run(self):
for batch in self.batch_generator:
self.queue.put(batch)
# Signal that fetcher is done.
self.queue.put(None)
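For reference, a minimal self-contained sketch of how such a fetcher process is typically consumed; the produce_batches generator and the consumer loop below are hypothetical, not part of this repository:

from multiprocessing import Process, Queue

class Fetcher(Process):
    def __init__(self, queue, batch_generator):
        super(Fetcher, self).__init__()
        self.queue = queue
        self.batch_generator = batch_generator

    def run(self):
        # producer: run the generator in the child process and feed the queue
        for batch in self.batch_generator:
            self.queue.put(batch)
        self.queue.put(None)  # sentinel marking the end of the stream

def produce_batches(num_batches):
    for i in range(num_batches):
        yield {'batch_id': i}

if __name__ == '__main__':
    queue = Queue(maxsize=4)
    Fetcher(queue, produce_batches(10)).start()
    while True:  # consumer: drain the queue until the sentinel arrives
        batch = queue.get()
        if batch is None:
            break
        print(batch['batch_id'])

Under a fork-based start method the child inherits the generator; under spawn the Process object must be picklable, so the generator would have to be constructed inside run() instead.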
class NumpyData(object):
def __init__(self, array):
@@ -85,6 +99,39 @@ class NumpyData(object):
return self.array[indices]
# def write_to_queue((queue, samples, data)):
# print samples
# if samples:
# batch = data.get(samples)
# queue.put(batch)
# else:
# queue.put(None)
# def create_generator(queue, index_generator, data):
# for samples in index_generator:
# yield (queue, samples, data)
# def parallel_async_batch_generator(data, index_generator, queue_maxsize, batch_function=None):
# pool = Pool(processes = 3)
# m = Manager()
# queue = m.Queue(maxsize=queue_maxsize)
# gen = create_generator(queue, index_generator, data)
# count =0
# for gen_sample in gen:
# gen_list = []
# while(count<3):
# gen_list.append(gen_sample)
# count += 1
# pool.map(write_to_queue, gen_list)
# workers = pool.map(write_to_queue, (queue,[],data))
# print 'Here Already'
# # enqueue_process = Process(target = pool.map, args=(write_to_queue, gen))
# queue_batcher = queue_generator(queue)
# return queue_batcher
if __name__=='__main__':
batch_size = 10
num_samples = 20
......
@@ -24,6 +24,7 @@ _object_labels = 'object_labels.json'
_attribute_labels = 'attribute_labels.json'
_regions_with_labels = 'region_with_labels.json'
_unknown_token = 'UNK'
_unopenable_images = 'unopenable_images.json'
_im_w = 224
_im_h = 224
@@ -264,12 +265,39 @@ def top_k_attribute_labels(k):
json.dump(attribute_labels, outfile, sort_keys=True, indent=4)
existing_region = 0
not_existing_region = 0
unopenable_images_list = []
unopenable_images_filename = os.path.join(_datadir, _unopenable_images)
def crop_region(region_info):
global existing_region
global not_existing_region
global unopenable_images_list
region_id, region_data = region_info
image_filename = os.path.join(_datadir,
'images/' +
str(region_data['image_id']) + '.jpg')
image = image_io.imread(image_filename)
image_subdir = os.path.join(_cropped_regions_dir,
str(region_data['image_id']))
image_out_filename = os.path.join(image_subdir,
str(region_id) + '.jpg')
if os.path.exists(image_out_filename):
existing_region += 1
return
else:
not_existing_region += 1
# print region_info, existing_region, not_existing_region
if not os.path.exists(image_subdir):
os.mkdir(image_subdir)
try:
image = image_io.imread(image_filename)
except:
print image_filename
unopenable_images_list.append(image_filename)
return
if len(image.shape)==3:
im_h, im_w, im_c =image.shape
@@ -287,18 +315,13 @@ def crop_region(region_info):
cropped_region = image_io.imresize(image[y:y+h,x:x+w,:],
output_size=(_im_h, _im_w))
image_subdir = os.path.join(_cropped_regions_dir,
str(region_data['image_id']))
if not os.path.exists(image_subdir):
os.mkdir(image_subdir)
image_out_filename = os.path.join(image_subdir,
str(region_id) + '.jpg')
image_io.imwrite(cropped_region, image_out_filename)
def crop_regions_parallel():
global unopenable_images_filename
global unopenable_images_list
global existing_region
regions_filename = os.path.join(_outdir,
_regions_with_labels)
with open(regions_filename) as file:
@@ -307,13 +330,20 @@ def crop_regions_parallel():
if not os.path.exists(_cropped_regions_dir):
os.mkdir(_cropped_regions_dir)
pool = Pool(24)
try:
pool.map(crop_region, regions.items())
except:
pool.close()
pool.close()
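# running crop_region in the parent process (loop below) keeps the global counters and unopenable_images_list visible to the json.dump at the end; Pool workers would only update their own copies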
for region_id, region_data in regions.items():
crop_region((region_id, region_data))
print existing_region
# pool = Pool(10)
# try:
# pool.map(crop_region, regions.items())
# except:
# pool.close()
# raise
# pool.close()
with open(unopenable_images_filename, 'w') as outfile:
json.dump(unopenable_images_list, outfile, sort_keys=True, indent=4)
def crop_regions():
regions_filename = os.path.join(_outdir,
......