From a3527aada52b52aedde199386bb7bae44fc72c13 Mon Sep 17 00:00:00 2001
From: tgupta6 <tgupta6@illinois.edu>
Date: Mon, 6 Jun 2016 13:04:10 -0500
Subject: [PATCH] Get multithreaded batch creation working

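Batch assembly now uses one daemon thread per sample: get_parallel
spawns the workers, each worker fills its own slot of a shared result
list, and the main thread joins them all before stitching the slots
into a batch. Threads pay off here despite the GIL because the
per-sample work is dominated by disk reads of cropped region JPEGs.

The slot-per-thread pattern in isolation, as a minimal runnable sketch
(the fetch function and sample values are illustrative stand-ins, not
code from this patch):

    import threading

    def fetch(sample, result, index):
        # Each worker writes only to its own slot, so no lock is
        # needed and output order matches input order.
        result[index] = sample * sample  # stand-in for image loading

    samples = range(8)
    results = [None] * len(samples)
    workers = []
    for index, sample in enumerate(samples):
        w = threading.Thread(target=fetch, args=(sample, results, index))
        w.setDaemon(True)
        w.start()
        workers.append(w)
    for w in workers:
        w.join()
    print results
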
---
 constants.py            |   2 +-
 data/cropped_regions.py | 132 +++++++++++++++++++++++++++++-----------
 tftools/data.py         |  53 +++++++++++++++-
 visual_genome_parser.py |  62 ++++++++++++++-----
 4 files changed, 192 insertions(+), 57 deletions(-)

diff --git a/constants.py b/constants.py
index b292f0f..adbc94a 100644
--- a/constants.py
+++ b/constants.py
@@ -8,7 +8,7 @@ unknown_token = 'UNK'
 
 # Data paths
 data_absolute_path = '/home/tanmay/Data/VisualGenome'
-image_dir = os.path.join(data_absolute_path, 'images')
+image_dir = os.path.join(data_absolute_path, 'cropped_regions')
 object_labels_json = os.path.join(
     data_absolute_path,
     'restructured/object_labels.json')
diff --git a/data/cropped_regions.py b/data/cropped_regions.py
index b99e6c0..f32f6d4 100644
--- a/data/cropped_regions.py
+++ b/data/cropped_regions.py
@@ -9,11 +9,17 @@ from multiprocessing import Pool
 import tftools.data 
 import image_io
 import constants
 
+import threading
 import tensorflow as tf
 
 _unknown_token = constants.unknown_token
 
+# Module-level wrapper so multiprocessing.Pool can pickle calls to
+# data.get_single (bound methods are not picklable in Python 2);
+# retained in case batch fetching moves back to a process pool.
+def unwrap_self_get_single(arg, **kwarg):
+    return data.get_single(*arg, **kwarg)
+
 class data():
     def __init__(self,
                  image_dir,
@@ -66,33 +73,81 @@ class data():
 
         for index, sample in enumerate(samples):
             batch['region_ids'][index] = self.sample_to_region_dict[sample]
-            batch['images'][index, :, :, :] = self.get_region_image(sample)
-            batch['object_labels'][index, :] = self.get_object_label(sample)
-            batch['attribute_labels'][index,:] = \
-                self.get_attribute_label(sample)
+            batch['images'][index, :, :, :], read_success = \
+                self.get_region_image(sample)
+            if read_success:
+                batch['object_labels'][index, :] = self.get_object_label(sample)
+                batch['attribute_labels'][index,:] = \
+                    self.get_attribute_label(sample)
 
         return batch
-    
-    def get_single(self, sample):
-        print sample
-        batch = dict()
-        batch['region_ids'] = dict()
-        batch['images'] = np.zeros(
+
+    def get_single(self, sample, result, index):
+        # Build zero-filled defaults up front so that result[index]
+        # receives a usable batch even if a read fails below.
+        batch = dict()
+        batch['region_ids'] = dict()
+        batch['images'] = np.zeros(
             [self.h, self.w, self.c], np.float32)
-        batch['object_labels'] = np.zeros(
-            [len(self.object_labels_dict)], np.float32)
-        batch['attribute_labels'] = np.zeros(
-            [len(self.attribute_labels_dict)], np.float32)
+        batch['object_labels'] = np.zeros(
+            [len(self.object_labels_dict)], np.float32)
+        batch['attribute_labels'] = np.zeros(
+            [len(self.attribute_labels_dict)], np.float32)
+
+        try:
+            batch['region_ids'] = self.sample_to_region_dict[sample]
+            batch['images'], read_success = self.get_region_image(sample)
+            if read_success:
+                batch['object_labels'] = self.get_object_label(sample)
+                batch['attribute_labels'] = self.get_attribute_label(sample)
 
-        batch['region_ids'] = self.sample_to_region_dict[sample]
-        batch['images'] = self.get_region_image(sample)
-        batch['object_labels'] = self.get_object_label(sample)
-        batch['attribute_labels'] = self.get_attribute_label(sample)
+        except Exception, e:
+            print str(e)
+            print threading.current_thread().name + ' errored'
+        # Publish the batch even on failure so the consumer never
+        # indexes into a None slot.
+        result[index] = batch
+
-        return batch
 
     def get_parallel(self, samples):
-        batch_list = self.pool.map(self.get_single, samples)
+        # One daemon thread per sample; each thread writes into its
+        # own slot of batch_list, so no locking is needed and sample
+        # order is preserved.
+        batch_list = [None]*len(samples)
+        jobs = []
+        for index, sample in enumerate(samples):
+            worker = threading.Thread(
+                target=self.get_single,
+                args=(sample, batch_list, index))
+            worker.setDaemon(True)
+            worker.start()
+            jobs.append(worker)
+
+        for worker in jobs:
+            worker.join()
+
         batch_size = len(samples)
         batch = dict()
         batch['region_ids'] = dict()
@@ -109,24 +164,25 @@ class data():
             batch['object_labels'][index, :] = single_batch['object_labels']
             batch['attribute_labels'][index,:] = single_batch['attribute_labels']
 
-
         return batch
-        
+
     def get_region_image(self, sample):
         region_id = self.sample_to_region_dict[sample]
         region = self.regions[region_id]
-        filename = os.path.join(self.image_dir,
-                                str(region['image_id']) + '.jpg')
-        image = image_io.imread(filename)
-        image = self.single_to_three_channel(image)
-        x, y, w, h = self.get_clipped_region_coords(region, image.shape[0:2])
-        region_image = image[y:y + h, x:x + w, :]
-
-        region_image = image_io.imresize(
-            region_image,
-            output_size=(self.h, self.w)).astype(np.float32)
+        image_subdir = os.path.join(self.image_dir,
+                                    str(region['image_id']))
+        filename = os.path.join(image_subdir,
+                                str(region_id) + '.jpg')
+        # Regions are pre-cropped to (h, w) on disk, so no resize is
+        # needed here; unreadable files fall back to a zero image.
+        read_success = True
+        try:
+            region_image = image_io.imread(filename)
+            region_image = self.single_to_three_channel(region_image)
+        except Exception:
+            read_success = False
+            region_image = np.zeros([self.h, self.w, self.c], np.float32)
+
+        region_image = region_image.astype(np.float32)
 
-        return region_image / 255 - self.mean_image
+        return region_image / 255 - self.mean_image, read_success
 
     def single_to_three_channel(self, image):
         if len(image.shape)==3:
@@ -223,11 +279,13 @@ if __name__=='__main__':
     print "Region: {}".format(region)
     print "Attributes: {}".format(", ".join(attributes))
     print "Objects: {}".format(", ".join(objects))
-    
+
 #    image_io.imshow(region_image)
 
     batch_size = 200
-    num_samples = 1000
+    num_samples = 10000
     num_epochs = 1
     offset = 0
 
@@ -249,7 +307,7 @@ if __name__=='__main__':
     batch_generator = tftools.data.async_batch_generator(
         data_mgr, 
         index_generator, 
-        1000)
+        100)
 
     count = 0 
     start = time.time()
@@ -261,4 +319,4 @@ if __name__=='__main__':
     print 'Time per batch: {}'.format((stop-start)/50.0)
     print "Count: {}".format(count)
 
-    pool.close()
+
diff --git a/tftools/data.py b/tftools/data.py
index 9fe60c1..5c00c8a 100644
--- a/tftools/data.py
+++ b/tftools/data.py
@@ -1,6 +1,6 @@
 """Utility functions for making batch generation easier."""
 import numpy as np
-from multiprocessing import Process, Queue
+from multiprocessing import Process, Queue, Manager, Pool
 import time
 
 def sequential(batch_size, num_samples, num_epochs=1, offset=0):
@@ -26,7 +26,7 @@ def batch_generator(data, index_generator, batch_function=None):
     """Generate batches of data.
     """
     for samples in index_generator:
-        batch = data.get(samples)
+        batch = data.get_parallel(samples)
         if batch_function:
             output = batch_function(batch)
         else:
@@ -43,7 +43,7 @@ def async_batch_generator(data, index_generator, queue_maxsize, batch_function=N
     fetcher = BatchFetcher(queue, batcher)
     fetcher.start()
 
-    time.sleep(10)
+    time.sleep(1)
     
     queue_batcher = queue_generator(queue)
     return queue_batcher
@@ -74,6 +74,20 @@ class BatchFetcher(Process):
         # Signal that fetcher is done.
         self.queue.put(None)
 
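+# Currently identical to BatchFetcher; presumably split out so a
+# parallel fetching strategy can evolve without touching the serial
+# path.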
+class BatchFetcherParallel(Process):
+    def __init__(self, queue, batch_generator):
+        super(BatchFetcherParallel, self).__init__()
+        self.queue = queue
+        self.batch_generator = batch_generator
+
+    def run(self):
+        for batch in self.batch_generator:
+            self.queue.put(batch)
+
+        # Signal that fetcher is done.
+        self.queue.put(None)
+
 
 class NumpyData(object):
     def __init__(self, array):
@@ -85,6 +99,39 @@ class NumpyData(object):
         return self.array[indices]
 
 
 if __name__=='__main__':
     batch_size = 10
     num_samples = 20
diff --git a/visual_genome_parser.py b/visual_genome_parser.py
index d104401..8847228 100644
--- a/visual_genome_parser.py
+++ b/visual_genome_parser.py
@@ -24,6 +24,7 @@ _object_labels = 'object_labels.json'
 _attribute_labels = 'attribute_labels.json'
 _regions_with_labels = 'region_with_labels.json'
 _unknown_token = 'UNK'
+_unopenable_images = 'unopenable_images.json'
 _im_w = 224
 _im_h = 224
 
@@ -264,12 +265,39 @@ def top_k_attribute_labels(k):
         json.dump(attribute_labels, outfile, sort_keys=True, indent=4)
  
 
+existing_region = 0
+not_existing_region = 0
+unopenable_images_list = []
+unopenable_images_filename = os.path.join(_datadir, _unopenable_images)
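+# Cropping is resumable: crops already on disk are counted in
+# existing_region and skipped; images that fail to open are collected
+# in unopenable_images_list and written out at the end of the run.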
 def crop_region(region_info):
+    global existing_region 
+    global not_existing_region 
+    global unopenable_images_list
     region_id, region_data = region_info
     image_filename = os.path.join(_datadir, 
                                   'images/' + 
                                   str(region_data['image_id']) + '.jpg')
-    image = image_io.imread(image_filename)
+    image_subdir = os.path.join(_cropped_regions_dir, 
+                                str(region_data['image_id']))
+    image_out_filename = os.path.join(image_subdir, 
+                                      str(region_id) + '.jpg')
+    if os.path.exists(image_out_filename):
+        existing_region += 1
+        return
+    else:
+        not_existing_region += 1
+
+    if not os.path.exists(image_subdir):
+        os.mkdir(image_subdir)
+
+    try:
+        image = image_io.imread(image_filename)
+    except Exception:
+        print 'Could not open: ' + image_filename
+        unopenable_images_list.append(image_filename)
+        return
     
     if len(image.shape)==3:
         im_h, im_w, im_c =image.shape
@@ -287,18 +315,13 @@ def crop_region(region_info):
     
     cropped_region = image_io.imresize(image[y:y+h,x:x+w,:],
                                        output_size=(_im_h, _im_w))
-    image_subdir = os.path.join(_cropped_regions_dir, 
-                                str(region_data['image_id']))
-    if not os.path.exists(image_subdir):
-        os.mkdir(image_subdir)
-        
-    image_out_filename = os.path.join(image_subdir, 
-                                      str(region_id) + '.jpg')
-        
     image_io.imwrite(cropped_region, image_out_filename)
 
 
 def crop_regions_parallel():
     regions_filename = os.path.join(_outdir,
                                     _regions_with_labels)
     with open(regions_filename) as file:
@@ -307,13 +330,20 @@ def crop_regions_parallel():
     if not os.path.exists(_cropped_regions_dir):
         os.mkdir(_cropped_regions_dir)
 
-    pool = Pool(24)
-    try:
-        pool.map(crop_region, regions.items())
-    except:
-        pool.close()
-    pool.close()
-
+    for region_id, region_data in regions.items():
+        crop_region((region_id, region_data))
+
+    print 'Regions already cropped: {}'.format(existing_region)
+
+    with open(unopenable_images_filename, 'w') as outfile:
+        json.dump(unopenable_images_list, outfile, sort_keys=True, indent=4)
 
 def crop_regions():
     regions_filename = os.path.join(_outdir,
-- 
GitLab