From d47d305ba5dff51b00dc7a13bc9ecd7df2d075dc Mon Sep 17 00:00:00 2001
From: tgupta6 <tgupta6@illinois.edu>
Date: Mon, 6 Jun 2016 15:46:42 -0500
Subject: [PATCH] Cleaned up commented lines

---
 data/cropped_regions.py | 82 ++++++++++-----------------------------
 tftools/data.py         | 43 +++-------------------
 visual_genome_parser.py | 42 ++++++----------------
 3 files changed, 30 insertions(+), 137 deletions(-)

diff --git a/data/cropped_regions.py b/data/cropped_regions.py
index f32f6d4..174b22b 100644
--- a/data/cropped_regions.py
+++ b/data/cropped_regions.py
@@ -3,16 +3,12 @@ import json
 import os
 import pdb
 import time
-
-from multiprocessing import Pool
+import threading
 
 import tftools.data
 import image_io
 import constants
 
-import Queue
-#import thread
-import threading
 import tensorflow as tf
 
 _unknown_token = constants.unknown_token
@@ -70,7 +66,6 @@ class data():
             [batch_size, len(self.object_labels_dict)], np.float32)
         batch['attribute_labels'] = np.zeros(
             [batch_size, len(self.attribute_labels_dict)], np.float32)
-
         for index, sample in enumerate(samples):
             batch['region_ids'][index] = self.sample_to_region_dict[sample]
             batch['images'][index, :, :, :], read_success = \
@@ -79,16 +74,10 @@ class data():
             batch['object_labels'][index, :] = self.get_object_label(sample)
             batch['attribute_labels'][index,:] = \
                 self.get_attribute_label(sample)
-
         return batch
 
-    def get_single(self, sample, result, index):
+    def get_single(self, sample, batch_list, worker_id):
         try:
-#            sample, lock = sample_lock
-#            lock.acquire()
-#            print sample
-#            self.num_threads += 1
-#            lock.release()
             batch = dict()
             batch['region_ids'] = dict()
             batch['images'] = np.zeros(
@@ -104,50 +93,27 @@ class data():
             batch['object_labels'] = self.get_object_label(sample)
             batch['attribute_labels'] = self.get_attribute_label(sample)
 
-            result[index] = batch
-
-        except Exception, e:
-            print str(e)
-            print threading.current_thread().name + ' errored'
-
-#        return batch
+            batch_list[worker_id] = batch
+        except Exception, e:
+            print 'Error in thread {}: {}'.format(
+                threading.current_thread().name, str(e))
 
     def get_parallel(self, samples):
-#        batch_list = self.pool.map(self.get_single, samples)
-        self.num_threads = 0
-        q = Queue.Queue()
         batch_list = [None]*len(samples)
-        count = 0
-        indices = range(len(samples))
-        jobs = []
-        for sample in samples:
-            worker = threading.Thread(target=self.get_single, args = (sample, batch_list, indices[count]))
+        worker_ids = range(len(samples))
+        workers = []
+        for count, sample in enumerate(samples):
+            worker = threading.Thread(
+                target = self.get_single,
+                args = (sample, batch_list, worker_ids[count]))
             worker.setDaemon(True)
             worker.start()
-            jobs.append(worker)
-            count += 1
+            workers.append(worker)
 
-        for worker in jobs:
+        for worker in workers:
             worker.join()
 
-        # while threading.active_count() > 0:
-        #     print '#Threads: {}'.format(threading.active_count())
-        #     pass
-        # lock = thread.allocate_lock()
-        # for sample in samples:
-        #     thread.start_new_thread(self.get_single, (sample,lock,q))
-        # while (self.num_threads > 4):
-        #     pass
-        # while (self.num_threads > 0):
-        #     pass
-        # thread.exit()
-
-        # batch_list = []
-        # while not q.empty():
-        #     batch_list.append(q.get())
-
-#        pool = Pool(3)
-#        batch_list = pool.map(unwrap_self_get_single, zip([self]*len(samples),samples))
-
         batch_size = len(samples)
         batch = dict()
         batch['region_ids'] = dict()
         batch['images'] = np.zeros(
@@ -280,12 +246,8 @@ if __name__=='__main__':
     print "Attributes: {}".format(", ".join(attributes))
     print "Objects: {}".format(", ".join(objects))
 
-#    pdb.set_trace()
-    data_mgr.regions['5581152']
-#    image_io.imshow(region_image)
-
     batch_size = 200
-    num_samples = 10000
+    num_samples = 200
     num_epochs = 1
     offset = 0
 
@@ -295,28 +257,14 @@ if __name__=='__main__':
         num_epochs,
         offset)
 
-    # start = time.time()
-    # count = 0
-    # for samples in index_generator:
-    #     batch = data_mgr.get(samples)
-    #     print 'Batch Count: {}'.format(count)
-    #     count += 1
-    # stop = time.time()
-    # print 'Time per batch: {}'.format((stop-start)/5.0)
-
     batch_generator = tftools.data.async_batch_generator(
         data_mgr,
         index_generator,
         100)
 
     count = 0
-    start = time.time()
     for batch in batch_generator:
         print 'Batch Number: {}'.format(count)
-#        print batch['region_ids']
         count += 1
-    stop = time.time()
-    print 'Time per batch: {}'.format((stop-start)/50.0)
-
     print "Count: {}".format(count)
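
Note on get_parallel: the method fans out one daemon thread per sample and joins them all before assembling the batch. Each thread writes only to its own slot of the preallocated batch_list, so no lock is needed around the shared list. A minimal standalone sketch of the same fan-out/join pattern (fetch_sample here is a placeholder, not a function from this repo):

    import threading

    def fetch_sample(sample):
        # Placeholder for the per-sample work done by data.get_single().
        return sample * sample

    def fetch_parallel(samples):
        # One slot per worker; each thread writes only to its own index,
        # so the list needs no lock.
        results = [None] * len(samples)

        def worker_fn(i, sample):
            results[i] = fetch_sample(sample)

        workers = []
        for i, sample in enumerate(samples):
            worker = threading.Thread(target=worker_fn, args=(i, sample))
            worker.setDaemon(True)  # do not block interpreter exit
            worker.start()
            workers.append(worker)

        for worker in workers:
            worker.join()
        return results

    print fetch_parallel(range(5))  # prints [0, 1, 4, 9, 16]

One thread per sample means a 200-sample batch spawns 200 threads; for I/O-bound region reads that is workable, but a fixed-size thread pool would bound the cost.
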
diff --git a/tftools/data.py b/tftools/data.py
index 5c00c8a..150cda6 100644
--- a/tftools/data.py
+++ b/tftools/data.py
@@ -26,7 +26,10 @@ def batch_generator(data, index_generator, batch_function=None):
     """Generate batches of data.
     """
     for samples in index_generator:
+#        start = time.time()
         batch = data.get_parallel(samples)
+#        stop = time.time()
+#        print 'Time: {}'.format(stop-start)
         if batch_function:
             output = batch_function(batch)
         else:
@@ -43,7 +46,7 @@ def async_batch_generator(data, index_generator, queue_maxsize, batch_function=N
     fetcher = BatchFetcher(queue, batcher)
     fetcher.start()
-    time.sleep(1)
+#    time.sleep(10)
 
     queue_batcher = queue_generator(queue)
     return queue_batcher
@@ -74,20 +77,6 @@ class BatchFetcher(Process):
 
         # Signal that fetcher is done.
         self.queue.put(None)
 
-class BatchFetcherParallel(Process):
-    def __init__(self, queue, batch_generator):
-        super(BatchFetcher, self).__init__()
-        self.queue = queue
-        self.batch_generator = batch_generator
-
-    def run(self):
-
-        for batch in self.batch_generator:
-            self.queue.put(batch)
-
-        # Signal that fetcher is done.
-        self.queue.put(None)
-
 class NumpyData(object):
 
     def __init__(self, array):
@@ -107,30 +96,6 @@ class NumpyData(object):
 #     else:
 #         queue.put(None)
 
-# def create_generator(queue, index_generator, data):
-#     for samples in index_generator:
-#         yield (queue, samples, data)
-
-# def parallel_async_batch_generator(data, index_generator, queue_maxsize, batch_function=None):
-#     pool = Pool(processes = 3)
-#     m = Manager()
-#     queue = m.Queue(maxsize=queue_maxsize)
-#     gen = create_generator(queue, index_generator, data)
-#     count =0
-#     for gen_sample in gen:
-#         gen_list = []
-#         while(count<3):
-#             gen_list.append(gen_sample)
-#             count += 1
-#         pool.map(write_to_queue, gen_list)
-#     workers = pool.map(write_to_queue, (queue,[],data))
-#     print 'Here Already'
-# #    enqueue_process = Process(target = pool.map, args=(write_to_queue, gen))
-#     queue_batcher = queue_generator(queue)
-#     return queue_batcher
-
-
-
 if __name__=='__main__':
 
     batch_size = 10
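
Note on async_batch_generator: BatchFetcher is a multiprocessing.Process that fills a bounded queue with prepared batches and pushes a None sentinel once the index generator is exhausted; queue_generator (its body is not shown in these hunks) presumably yields from the queue until it sees that sentinel. A minimal sketch of this producer/consumer pattern, with a stand-in producer:

    from multiprocessing import Process, Queue

    def producer(queue, n):
        # Stand-in for BatchFetcher.run(): push items, then a None sentinel.
        for i in range(n):
            queue.put(i)
        queue.put(None)

    def queue_generator(queue):
        # Drain the queue until the producer's sentinel arrives.
        while True:
            item = queue.get()
            if item is None:
                return
            yield item

    if __name__ == '__main__':
        queue = Queue(maxsize=4)  # bounded, like queue_maxsize above
        fetcher = Process(target=producer, args=(queue, 10))
        fetcher.start()
        for item in queue_generator(queue):
            print item
        fetcher.join()

Because queue.get() blocks, the consumer simply waits for the fetcher, so the commented-out time.sleep() head start above is not required for correctness.
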
diff --git a/visual_genome_parser.py b/visual_genome_parser.py
index 8847228..5a57cf9 100644
--- a/visual_genome_parser.py
+++ b/visual_genome_parser.py
@@ -27,6 +27,7 @@ _unknown_token = 'UNK'
 _unopenable_images = 'unopenable_images.json'
 _im_w = 224
 _im_h = 224
+_pool_size = 10
 
 if not os.path.exists(_outdir):
     os.mkdir(_outdir)
@@ -265,14 +266,7 @@ def top_k_attribute_labels(k):
         json.dump(attribute_labels, outfile, sort_keys=True, indent=4)
 
-existing_region = 0
-not_existing_region = 0
-unopenable_images_list = []
-unopenable_images_filename = os.path.join(_datadir, _unopenable_images)
 def crop_region(region_info):
-    global existing_region
-    global not_existing_region
-    global unopenable_images_list
     region_id, region_data = region_info
     image_filename = os.path.join(_datadir,
                                   'images/' +
@@ -282,21 +276,15 @@ def crop_region(region_info):
     image_out_filename = os.path.join(image_subdir,
                                       str(region_id) + '.jpg')
     if os.path.exists(image_out_filename):
-        existing_region += 1
         return
-    else:
-        not_existing_region += 1
-#    print region_info, existing_region, not_existing_region
 
     if not os.path.exists(image_subdir):
         os.mkdir(image_subdir)
-
     try:
         image = image_io.imread(image_filename)
     except:
-        print image_filename
-        unopenable_images_list.append(image_filename)
+        print 'Could not read image: {}'.format(image_filename)
         return
 
     if len(image.shape)==3:
@@ -319,9 +307,6 @@ def crop_regions_parallel():
-    global unopenable_images_filename
-    global unopenable_images_list
-    global existing_region
     regions_filename = os.path.join(_outdir, _regions_with_labels)
 
     with open(regions_filename) as file:
@@ -330,21 +315,16 @@ def crop_regions_parallel():
     if not os.path.exists(_cropped_regions_dir):
         os.mkdir(_cropped_regions_dir)
 
-    for region_id, region_data in regions.items():
-        crop_region((region_id, region_data))
-
-    print existing_region
-    # pool = Pool(10)
-    # try:
-    #     pool.map(crop_region, regions.items())
-    # except:
-    #     pool.close()
-    #     raise
-    # pool.close()
+    pool = Pool(_pool_size)
+    try:
+        pool.map(crop_region, regions.items())
+    except:
+        pool.close()
+        raise
+    pool.close()
 
-    with open(unopenable_images_filename, 'w') as outfile:
-        json.dump(unopenable_images_list, outfile, sort_keys=True, indent=4)
+
 def crop_regions():
 
     regions_filename = os.path.join(_outdir, _regions_with_labels)
@@ -405,4 +385,4 @@ if __name__=='__main__':
 #    normalize_region_object_attribute_labels()
 #    top_k_object_labels(1000)
 #    top_k_attribute_labels(1000)
-    crop_regions_parallel()
+    # crop_regions_parallel()
-- 
GitLab
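
Note on crop_regions_parallel: with multiprocessing.Pool, crop_region executes in worker processes, so appending to a parent-process global (as the removed unopenable_images_list code did) never propagates back to the parent; accordingly, no list of unopenable images is written out anymore. If that record is still wanted, the process-safe route is to return the failing filename from the mapped function and collect pool.map's results. A sketch under that assumption (crop_region_safe and the open() check are illustrative placeholders, not repo code):

    from multiprocessing import Pool

    def crop_region_safe(region_info):
        # Return the filename on failure, None on success,
        # instead of mutating a global list.
        region_id, image_filename = region_info
        try:
            open(image_filename).close()  # placeholder for image_io.imread
            return None
        except IOError:
            return image_filename

    if __name__ == '__main__':
        regions = [(1, 'a.jpg'), (2, 'b.jpg')]
        pool = Pool(4)
        try:
            unopenable = [f for f in pool.map(crop_region_safe, regions)
                          if f is not None]
        finally:
            pool.close()
        print unopenable
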