From a39ea4335eb924918fb6f56f70f73b2f888f1dea Mon Sep 17 00:00:00 2001
From: Kevin Shih <kevin.j.shih@gmail.com>
Date: Tue, 29 Mar 2016 00:12:37 -0500
Subject: [PATCH] switching to random generation instead of recursive
 enumeration

---
 shapes_dataset/gen_dataset.py | 69 ++++++++++++++++++++++++++++++++++-
 1 file changed, 68 insertions(+), 1 deletion(-)

diff --git a/shapes_dataset/gen_dataset.py b/shapes_dataset/gen_dataset.py
index 80c66c5..7cf4e9f 100755
--- a/shapes_dataset/gen_dataset.py
+++ b/shapes_dataset/gen_dataset.py
@@ -134,6 +134,68 @@ def gen_questions(block_array):
     return qas
 
 
+def enumerate_all_sample(target_set_size, on_train, eset = None):
+    """Sample unique random 9-cell configs, then render and log each one.
+
+    Each cell gets a (shape, color) digit pair; shape 0 leaves the cell at
+    its zeroed default. With on_train == 1, pairs where shape == color + 1
+    are held out by re-drawing the colour. Configs found in eset are
+    rejected. Saves one image per config under OUTDIR, appends one JSON
+    record per question to f_handle, and returns the set of config strings.
+    """
+    global im_id_cnt
+    global qs_id_cnt
+    global first
+
+    out_set = set()
+    while len(out_set) < target_set_size:
+        config = [0] * 18
+        for i in xrange(0, 18, 2):
+            sample = random.randint(0, 4) - 1
+            if sample <= 0:
+                continue # nothing drawn: shape and color both stay 0
+            rand_col = random.randint(0, 2)
+            if on_train == 1 and sample == rand_col + 1: # enforce held-out
+                rand_col = (rand_col + random.choice([1, -1])) % 3
+            config[i] = sample # set shape
+            config[i+1] = rand_col # set color
+
+        config_str = ''.join(str(c) for c in config)
+        # skip if it's in the eset (pre-existing set)
+        if eset is not None and config_str in eset:
+            continue
+        out_set.add(config_str)
+
+    # generate the examples
+    for config in out_set:
+        block_array = [Block(int(config[i]), int(config[i+1]))
+                       for i in xrange(0, 18, 2)]
+
+        qas = gen_questions(block_array)
+        im = gen_image(block_array)
+        im.save(OUTDIR + '/' + str(im_id_cnt) + '.jpg')
+
+        # the config list is identical for every question about this image,
+        # so serialise it once here rather than once per question
+        config_json = ', '.join("%d, %d" % (blk.shape, blk.color)
+                                for blk in block_array) + ' ]'
+
+        for qa in qas:
+            if first:
+                f_handle.write('{')
+                first = False
+            else:
+                f_handle.write(',{')
+            qstr = "\"question_id\": %d, \"image_id\" : %d, \"question\": \"%s\", \"answer\": \"%s\", \"config\": [ " %(qs_id_cnt, im_id_cnt, qa.q, qa.a)
+            f_handle.write(qstr)
+            f_handle.write(config_json)
+            f_handle.write('}\n')
+            qs_id_cnt = qs_id_cnt + 1
+
+        im_id_cnt += 1
+
+    return out_set
+
 def enumerate_all(block_array):
     global im_id_cnt
     if im_id_cnt == num_to_gen:
@@ -211,7 +273,12 @@ if __name__ == "__main__":
     f_handle = f
     f.write('[')
 
-    enumerate_all([])
+    #enumerate_all([])
+    print "Generating training examples..."
+    config_set = enumerate_all_sample(10000, 1, eset = None)
+    print "Generating test examples..."
+    # pass the training configs in when generating test examples to avoid duplicates
+    enumerate_all_sample(5000, 0, eset = config_set)
     f.write(']')
     f.close()
 #    [a b c; d e f; g h i]
-- 
GitLab