From d47d305ba5dff51b00dc7a13bc9ecd7df2d075dc Mon Sep 17 00:00:00 2001
From: tgupta6 <tgupta6@illinois.edu>
Date: Mon, 6 Jun 2016 15:46:42 -0500
Subject: [PATCH] Clean up commented-out code

---
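Notes (review notes only, not part of the commit message):

The rewritten get_parallel in data/cropped_regions.py fans each sample out
to its own daemon thread, and every thread writes its result into its own
slot of a preallocated list, so the threads never touch the same element
and no lock is needed. A minimal standalone sketch of that pattern
(Python 2 to match the codebase; fetch_one is an illustrative stand-in
for the real per-sample work, not a function from this repo):

    import threading

    def fetch_one(sample, results, slot):
        # Stand-in for the real work (e.g. read and resize one region).
        results[slot] = sample * sample

    samples = [1, 2, 3, 4]
    results = [None] * len(samples)  # one slot per thread, so no lock
    workers = []
    for slot, sample in enumerate(samples):
        worker = threading.Thread(target=fetch_one,
                                  args=(sample, results, slot))
        worker.setDaemon(True)  # do not let workers block interpreter exit
        worker.start()
        workers.append(worker)
    for worker in workers:
        worker.join()  # wait until every slot has been filled
    print results  # prints [1, 4, 9, 16]
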
 data/cropped_regions.py | 81 ++++------------------------
 tftools/data.py         | 39 ----------------------
 visual_genome_parser.py | 41 +++++-----------------
 3 files changed, 24 insertions(+), 137 deletions(-)
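
The new crop_regions_parallel maps crop_region over a multiprocessing
Pool. Pool workers run in separate processes, so appending to a
module-level list (the old unopenable_images_list) never reaches the
parent; the dump of that list is therefore dropped below. If collecting
failures is still wanted, returning them through pool.map is one option.
A sketch under that assumption (process_item is a made-up stand-in, not
a function from this repo):

    from multiprocessing import Pool

    def process_item(item):
        # Report a failure through the return value instead of mutating
        # a global list; globals changed in child processes are not
        # visible to the parent.
        try:
            1.0 / item  # stand-in for the real work; fails on 0
            return None
        except ZeroDivisionError:
            return item

    if __name__ == '__main__':
        pool = Pool(4)
        try:
            results = pool.map(process_item, [1, 2, 0, 3])
        finally:
            pool.close()
            pool.join()
        print [r for r in results if r is not None]  # prints [0]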

diff --git a/data/cropped_regions.py b/data/cropped_regions.py
index f32f6d4..174b22b 100644
--- a/data/cropped_regions.py
+++ b/data/cropped_regions.py
@@ -3,16 +3,12 @@ import json
 import os
 import pdb
 import time
-
-from multiprocessing import Pool
+import threading
 
 import tftools.data 
 import image_io
 import constants
-import Queue
 
-#import thread
-import threading
 import tensorflow as tf
 
 _unknown_token = constants.unknown_token
@@ -70,7 +66,6 @@ class data():
             [batch_size, len(self.object_labels_dict)], np.float32)
         batch['attribute_labels'] = np.zeros(
             [batch_size, len(self.attribute_labels_dict)], np.float32)
-
         for index, sample in enumerate(samples):
             batch['region_ids'][index] = self.sample_to_region_dict[sample]
             batch['images'][index, :, :, :], read_success = \
@@ -79,16 +74,10 @@ class data():
                 batch['object_labels'][index, :] = self.get_object_label(sample)
                 batch['attribute_labels'][index,:] = \
                     self.get_attribute_label(sample)
-
         return batch
 
-    def get_single(self, sample, result, index):
+    def get_single(self, sample, batch_list, worker_id):
         try:
-#        sample, lock = sample_lock
-#        lock.acquire()
-#        print sample
-#        self.num_threads += 1
-#        lock.release()
             batch = dict()
             batch['region_ids'] = dict()
             batch['images'] = np.zeros(
@@ -104,50 +93,26 @@ class data():
                 batch['object_labels'] = self.get_object_label(sample)
                 batch['attribute_labels'] = self.get_attribute_label(sample)
 
-            result[index] = batch
-        except Exception, e:
-            print str(e)
-            print threading.current_thread().name + ' errored'
-
-#        return batch
+            batch_list[worker_id] = batch
 
+        except Exception as e:
+            print 'Error in thread {}: {}'.format(
+                threading.current_thread().name, e)
 
     def get_parallel(self, samples):
-#        batch_list = self.pool.map(self.get_single, samples)
-        self.num_threads = 0
-        q = Queue.Queue()
         batch_list = [None]*len(samples)
-        count = 0
-        indices = range(len(samples))
-        jobs = []
-        for sample in samples:
-            worker = threading.Thread(target=self.get_single, args = (sample, batch_list, indices[count]))
+        workers = []
+        for count, sample in enumerate(samples):
+            worker = threading.Thread(
+                target=self.get_single,
+                args=(sample, batch_list, count))
             worker.setDaemon(True)
             worker.start()
-            jobs.append(worker)
-            count += 1
+            workers.append(worker)
         
-        for worker in jobs:
+        for worker in workers:
             worker.join()
 
-         # while threading.active_count() > 0:
-         #    print '#Threads: {}'.format(threading.active_count())
-         #    pass
-        # lock = thread.allocate_lock()
-        # for sample in samples:
-        #     thread.start_new_thread(self.get_single, (sample,lock,q))
-        #     while (self.num_threads > 4):
-        #         pass
-        # while (self.num_threads > 0):
-        #     pass
-        # thread.exit()
-
-        # batch_list = []
-        # while not q.empty():
-        #     batch_list.append(q.get())
-#        pool = Pool(3)
-#        batch_list = pool.map(unwrap_self_get_single, zip([self]*len(samples),samples))
-        
         batch_size = len(samples)
         batch = dict()
         batch['region_ids'] = dict()
@@ -280,12 +245,8 @@ if __name__=='__main__':
     print "Attributes: {}".format(", ".join(attributes))
     print "Objects: {}".format(", ".join(objects))
 
-#    pdb.set_trace()
-    data_mgr.regions['5581152']
-#    image_io.imshow(region_image)
-
     batch_size = 200
-    num_samples = 10000
+    num_samples = 200
     num_epochs = 1
     offset = 0
 
@@ -295,28 +256,14 @@ if __name__=='__main__':
         num_epochs, 
         offset)
     
-    # start = time.time()
-    # count = 0
-    # for samples in index_generator:
-    #     batch = data_mgr.get(samples)
-    #     print 'Batch Count: {}'.format(count)
-    #     count += 1
-    # stop = time.time()
-    # print 'Time per batch: {}'.format((stop-start)/5.0)
-
     batch_generator = tftools.data.async_batch_generator(
         data_mgr, 
         index_generator, 
         100)
 
     count = 0 
-    start = time.time()
     for batch in batch_generator:
         print 'Batch Number: {}'.format(count)
-#        print batch['region_ids']
         count += 1
-    stop = time.time()
-    print 'Time per batch: {}'.format((stop-start)/50.0)
-    print "Count: {}".format(count)
 
     
diff --git a/tftools/data.py b/tftools/data.py
index 5c00c8a..150cda6 100644
--- a/tftools/data.py
+++ b/tftools/data.py
@@ -43,7 +43,6 @@ def async_batch_generator(data, index_generator, queue_maxsize, batch_function=N
     fetcher = BatchFetcher(queue, batcher)
     fetcher.start()
 
-    time.sleep(1)
     
     queue_batcher = queue_generator(queue)
     return queue_batcher
@@ -74,20 +73,6 @@ class BatchFetcher(Process):
         # Signal that fetcher is done.
         self.queue.put(None)
 
-class BatchFetcherParallel(Process):
-    def __init__(self, queue, batch_generator):
-        super(BatchFetcher, self).__init__()
-        self.queue = queue
-        self.batch_generator = batch_generator
-
-    def run(self):
-        
-        for batch in self.batch_generator:
-            self.queue.put(batch)
-
-        # Signal that fetcher is done.
-        self.queue.put(None)
-
 
 class NumpyData(object):
     def __init__(self, array):
@@ -107,30 +92,6 @@ class NumpyData(object):
 #     else:
 #         queue.put(None)
 
-# def create_generator(queue, index_generator, data):
-#     for samples in index_generator:
-#         yield (queue, samples, data)
-
-# def parallel_async_batch_generator(data, index_generator, queue_maxsize, batch_function=None):
-#     pool = Pool(processes = 3)
-#     m = Manager()
-#     queue = m.Queue(maxsize=queue_maxsize)
-#     gen = create_generator(queue, index_generator, data)
-#     count =0
-#     for gen_sample in gen:
-#         gen_list = []
-#         while(count<3):
-#             gen_list.append(gen_sample)
-#             count += 1
-#         pool.map(write_to_queue, gen_list)        
-#     workers = pool.map(write_to_queue, (queue,[],data))
-#     print 'Here Already'
-# #    enqueue_process = Process(target = pool.map, args=(write_to_queue, gen))
-#     queue_batcher = queue_generator(queue)
-#     return queue_batcher
-    
-
-    
 
 if __name__=='__main__':
     batch_size = 10
diff --git a/visual_genome_parser.py b/visual_genome_parser.py
index 8847228..5a57cf9 100644
--- a/visual_genome_parser.py
+++ b/visual_genome_parser.py
@@ -27,6 +27,7 @@ _unknown_token = 'UNK'
 _unopenable_images = 'unopenable_images.json'
 _im_w = 224
 _im_h = 224
+_pool_size = 10
 
 if not os.path.exists(_outdir):
     os.mkdir(_outdir)
@@ -265,14 +266,7 @@ def top_k_attribute_labels(k):
         json.dump(attribute_labels, outfile, sort_keys=True, indent=4)
  
 
-existing_region = 0
-not_existing_region = 0
-unopenable_images_list = []
-unopenable_images_filename = os.path.join(_datadir, _unopenable_images)
 def crop_region(region_info):
-    global existing_region 
-    global not_existing_region 
-    global unopenable_images_list
     region_id, region_data = region_info
     image_filename = os.path.join(_datadir, 
                                   'images/' + 
@@ -282,21 +276,15 @@ def crop_region(region_info):
     image_out_filename = os.path.join(image_subdir, 
                                       str(region_id) + '.jpg')
     if os.path.exists(image_out_filename):
-        existing_region += 1
         return
-    else:
-        not_existing_region += 1
-#        print region_info, existing_region, not_existing_region
 
     if not os.path.exists(image_subdir):
         os.mkdir(image_subdir)
-        
 
     try:
         image = image_io.imread(image_filename)
     except:
-        print image_filename
-        unopenable_images_list.append(image_filename)
+        print 'Could not read image: {}'.format(image_filename)
         return
     
     if len(image.shape)==3:
@@ -319,9 +307,6 @@ def crop_region(region_info):
 
 
 def crop_regions_parallel():
-    global unopenable_images_filename
-    global unopenable_images_list
-    global existing_region
     regions_filename = os.path.join(_outdir,
                                     _regions_with_labels)
     with open(regions_filename) as file:
@@ -330,21 +315,15 @@ def crop_regions_parallel():
     if not os.path.exists(_cropped_regions_dir):
         os.mkdir(_cropped_regions_dir)
 
-    for region_id, region_data in regions.items():
-        crop_region((region_id, region_data))
-
-    print existing_region
-    # pool = Pool(10)
-    # try:
-    #     pool.map(crop_region, regions.items())
-    # except:
-    #     pool.close()
-    #     raise
-    # pool.close()
+    pool = Pool(_pool_size)
+    try:
+        pool.map(crop_region, regions.items())
+    except:
+        pool.close()
+        raise
+    pool.close()
 
-    with open(unopenable_images_filename, 'w') as outfile:
-        json.dump(unopenable_images_list, outfile, sort_keys=True, indent=4)
 
 def crop_regions():
     regions_filename = os.path.join(_outdir,
                                     _regions_with_labels)
@@ -405,4 +384,4 @@ if __name__=='__main__':
     # normalize_region_object_attribute_labels()
     # top_k_object_labels(1000)
     # top_k_attribute_labels(1000)
-    crop_regions_parallel()
+    # crop_regions_parallel()
-- 
GitLab