Skip to content
Snippets Groups Projects
Commit d47d305b authored by tgupta6's avatar tgupta6
Browse files

Cleaned up commented lines

parent a3527aad
No related branches found
No related tags found
No related merge requests found
......@@ -3,16 +3,12 @@ import json
import os
import pdb
import time
from multiprocessing import Pool
import threading
import tftools.data
import image_io
import constants
import Queue
#import thread
import threading
import tensorflow as tf
_unknown_token = constants.unknown_token
......@@ -70,7 +66,6 @@ class data():
[batch_size, len(self.object_labels_dict)], np.float32)
batch['attribute_labels'] = np.zeros(
[batch_size, len(self.attribute_labels_dict)], np.float32)
for index, sample in enumerate(samples):
batch['region_ids'][index] = self.sample_to_region_dict[sample]
batch['images'][index, :, :, :], read_success = \
......@@ -79,16 +74,10 @@ class data():
batch['object_labels'][index, :] = self.get_object_label(sample)
batch['attribute_labels'][index,:] = \
self.get_attribute_label(sample)
return batch
def get_single(self, sample, result, index):
def get_single(self, sample, batch_list, worker_id):
try:
# sample, lock = sample_lock
# lock.acquire()
# print sample
# self.num_threads += 1
# lock.release()
batch = dict()
batch['region_ids'] = dict()
batch['images'] = np.zeros(
......@@ -104,50 +93,27 @@ class data():
batch['object_labels'] = self.get_object_label(sample)
batch['attribute_labels'] = self.get_attribute_label(sample)
result[index] = batch
except Exception, e:
print str(e)
print threading.current_thread().name + ' errored'
# return batch
batch_list[worker_id] = batch
except Exception, e:
print 'Error in thread {}: {}'.format(
threading.current_thread().name, str(e))
def get_parallel(self, samples):
# batch_list = self.pool.map(self.get_single, samples)
self.num_threads = 0
q = Queue.Queue()
batch_list = [None]*len(samples)
count = 0
indices = range(len(samples))
jobs = []
for sample in samples:
worker = threading.Thread(target=self.get_single, args = (sample, batch_list, indices[count]))
worker_ids = range(len(samples))
workers = []
for count, sample in enumerate(samples):
worker = threading.Thread(
target = self.get_single,
args = (sample, batch_list, worker_ids[count]))
worker.setDaemon(True)
worker.start()
jobs.append(worker)
count += 1
workers.append(worker)
for worker in jobs:
worker.join()
# while threading.active_count() > 0:
# print '#Threads: {}'.format(threading.active_count())
# pass
# lock = thread.allocate_lock()
# for sample in samples:
# thread.start_new_thread(self.get_single, (sample,lock,q))
# while (self.num_threads > 4):
# pass
# while (self.num_threads > 0):
# pass
# thread.exit()
# batch_list = []
# while not q.empty():
# batch_list.append(q.get())
# pool = Pool(3)
# batch_list = pool.map(unwrap_self_get_single, zip([self]*len(samples),samples))
batch_size = len(samples)
batch = dict()
batch['region_ids'] = dict()
......@@ -280,12 +246,8 @@ if __name__=='__main__':
print "Attributes: {}".format(", ".join(attributes))
print "Objects: {}".format(", ".join(objects))
# pdb.set_trace()
data_mgr.regions['5581152']
# image_io.imshow(region_image)
batch_size = 200
num_samples = 10000
num_samples = 200
num_epochs = 1
offset = 0
......@@ -295,28 +257,14 @@ if __name__=='__main__':
num_epochs,
offset)
# start = time.time()
# count = 0
# for samples in index_generator:
# batch = data_mgr.get(samples)
# print 'Batch Count: {}'.format(count)
# count += 1
# stop = time.time()
# print 'Time per batch: {}'.format((stop-start)/5.0)
batch_generator = tftools.data.async_batch_generator(
data_mgr,
index_generator,
100)
count = 0
start = time.time()
for batch in batch_generator:
print 'Batch Number: {}'.format(count)
# print batch['region_ids']
count += 1
stop = time.time()
print 'Time per batch: {}'.format((stop-start)/50.0)
print "Count: {}".format(count)
......@@ -26,7 +26,10 @@ def batch_generator(data, index_generator, batch_function=None):
"""Generate batches of data.
"""
for samples in index_generator:
# start = time.time()
batch = data.get_parallel(samples)
# stop = time.time()
# print 'Time: {}'.format(stop-start)
if batch_function:
output = batch_function(batch)
else:
......@@ -43,7 +46,7 @@ def async_batch_generator(data, index_generator, queue_maxsize, batch_function=N
fetcher = BatchFetcher(queue, batcher)
fetcher.start()
time.sleep(1)
# time.sleep(10)
queue_batcher = queue_generator(queue)
return queue_batcher
......@@ -74,20 +77,6 @@ class BatchFetcher(Process):
# Signal that fetcher is done.
self.queue.put(None)
class BatchFetcherParallel(Process):
    """Background process that drains a batch generator into a queue.

    Consumes ``batch_generator`` and puts each batch on ``queue``; a final
    ``None`` sentinel signals that the fetcher is done.
    """
    def __init__(self, queue, batch_generator):
        # Bug fix: the original called super(BatchFetcher, ...), i.e. the
        # wrong class, which skips this class in the MRO. Name this class.
        super(BatchFetcherParallel, self).__init__()
        self.queue = queue
        self.batch_generator = batch_generator

    def run(self):
        for batch in self.batch_generator:
            self.queue.put(batch)
        # Signal that fetcher is done.
        self.queue.put(None)
class NumpyData(object):
def __init__(self, array):
......@@ -107,30 +96,6 @@ class NumpyData(object):
# else:
# queue.put(None)
# def create_generator(queue, index_generator, data):
# for samples in index_generator:
# yield (queue, samples, data)
# def parallel_async_batch_generator(data, index_generator, queue_maxsize, batch_function=None):
# pool = Pool(processes = 3)
# m = Manager()
# queue = m.Queue(maxsize=queue_maxsize)
# gen = create_generator(queue, index_generator, data)
# count =0
# for gen_sample in gen:
# gen_list = []
# while(count<3):
# gen_list.append(gen_sample)
# count += 1
# pool.map(write_to_queue, gen_list)
# workers = pool.map(write_to_queue, (queue,[],data))
# print 'Here Already'
# # enqueue_process = Process(target = pool.map, args=(write_to_queue, gen))
# queue_batcher = queue_generator(queue)
# return queue_batcher
if __name__=='__main__':
batch_size = 10
......
......@@ -27,6 +27,7 @@ _unknown_token = 'UNK'
_unopenable_images = 'unopenable_images.json'
_im_w = 224
_im_h = 224
_pool_size = 10
if not os.path.exists(_outdir):
os.mkdir(_outdir)
......@@ -265,14 +266,7 @@ def top_k_attribute_labels(k):
json.dump(attribute_labels, outfile, sort_keys=True, indent=4)
existing_region = 0
not_existing_region = 0
unopenable_images_list = []
unopenable_images_filename = os.path.join(_datadir, _unopenable_images)
def crop_region(region_info):
global existing_region
global not_existing_region
global unopenable_images_list
region_id, region_data = region_info
image_filename = os.path.join(_datadir,
'images/' +
......@@ -282,21 +276,15 @@ def crop_region(region_info):
image_out_filename = os.path.join(image_subdir,
str(region_id) + '.jpg')
if os.path.exists(image_out_filename):
existing_region += 1
return
else:
not_existing_region += 1
# print region_info, existing_region, not_existing_region
if not os.path.exists(image_subdir):
os.mkdir(image_subdir)
try:
image = image_io.imread(image_filename)
except:
print image_filename
unopenable_images_list.append(image_filename)
print 'Could not read image: {}'.format(image_filename)
return
if len(image.shape)==3:
......@@ -319,9 +307,6 @@ def crop_region(region_info):
def crop_regions_parallel():
global unopenable_images_filename
global unopenable_images_list
global existing_region
regions_filename = os.path.join(_outdir,
_regions_with_labels)
with open(regions_filename) as file:
......@@ -330,21 +315,18 @@ def crop_regions_parallel():
if not os.path.exists(_cropped_regions_dir):
os.mkdir(_cropped_regions_dir)
for region_id, region_data in regions.items():
crop_region((region_id, region_data))
print existing_region
# pool = Pool(10)
# try:
# pool.map(crop_region, regions.items())
# except:
# pool.close()
# raise
# pool.close()
pool = Pool(_pool_size)
try:
pool.map(crop_region, regions.items())
except:
pool.close()
raise
pool.close()
with open(unopenable_images_filename, 'w') as outfile:
json.dump(unopenable_images_list, outfile, sort_keys=True, indent=4)
def crop_regions():
regions_filename = os.path.join(_outdir,
_regions_with_labels)
......@@ -405,4 +387,4 @@ if __name__=='__main__':
# normalize_region_object_attribute_labels()
# top_k_object_labels(1000)
# top_k_attribute_labels(1000)
crop_regions_parallel()
# crop_regions_parallel()
0% — Loading, or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment