def _sample_config(on_train):
    """Draw one random configuration and return it as an 18-digit string.

    The configuration describes nine cells, each as a (shape, color) digit
    pair.  A cell's shape is 0 with probability 2/5 (presumably an empty
    cell; its color digit is then left at 0), otherwise 1, 2 or 3 with a
    color drawn uniformly from 0..2.

    When ``on_train == 1`` the combinations where ``shape == color + 1`` are
    held out: the color is shifted by +/-1 (mod 3) so that such a pair never
    appears in the returned configuration.
    """
    digits = []
    for _ in range(9):
        shape = random.randint(0, 4) - 1  # -1..3; -1 and 0 both collapse to 0
        color = 0
        if shape <= 0:
            shape = 0
        else:
            color = random.randint(0, 2)
            if on_train == 1 and shape == color + 1:  # enforce held-out
                color = (color + random.choice([1, -1])) % 3
        digits.append(shape)
        digits.append(color)
    return ''.join(str(d) for d in digits)


def enumerate_all_sample(target_set_size, on_train, eset=None):
    """Randomly sample distinct configurations and emit their examples.

    Replaces exhaustive recursive enumeration with rejection sampling:
    configurations are drawn at random until ``target_set_size`` distinct
    ones have been collected.  For every collected configuration an image is
    rendered and saved to ``OUTDIR/<image id>.jpg`` and one JSON object per
    generated question is written to the module-level ``f_handle``.

    Args:
        target_set_size: number of distinct configurations to collect.
        on_train: pass 1 to exclude the held-out shape/color combinations
            (see ``_sample_config``); any other value samples freely.
        eset: optional collection of configuration strings that must not be
            produced again (e.g. a previously generated set).

    Returns:
        The set of 18-digit configuration strings that were generated.

    Side effects:
        Advances the globals ``im_id_cnt`` and ``qs_id_cnt`` and clears
        ``first`` once the first JSON object has been written.

    Note:
        The sampling loop only ends once enough distinct configurations are
        found, so it never terminates if ``target_set_size`` exceeds the
        number of configurations that are still available.
    """
    global im_id_cnt
    global qs_id_cnt
    global first
    # Local import keeps this block self-contained; json is used only to
    # escape the question/answer text.
    import json

    out_set = set()
    while len(out_set) < target_set_size:
        config_str = _sample_config(on_train)
        # skip configurations that already exist in the pre-existing set
        if eset is not None and config_str in eset:
            continue
        out_set.add(config_str)

    # generate the examples
    for config in out_set:
        block_array = [Block(int(config[i]), int(config[i + 1]))
                       for i in range(0, 18, 2)]

        qas = gen_questions(block_array)
        im = gen_image(block_array)
        im.save(OUTDIR + '/' + str(im_id_cnt) + '.jpg')

        # The flattened "shape, color" list is the same for every question
        # about this image, so build it once.
        config_items = ', '.join('%d, %d' % (blk.shape, blk.color)
                                 for blk in block_array)

        for qa in qas:
            # Objects are comma-separated inside one JSON array; only the
            # very first object written omits the leading comma.
            if first:
                f_handle.write('{')
                first = False
            else:
                f_handle.write(',{')
            # json.dumps adds the surrounding quotes and escapes quotes,
            # backslashes and control characters so the output stays valid
            # JSON; '%s' keeps the original stringification of q and a.
            f_handle.write(
                '"question_id": %d, "image_id" : %d, '
                '"question": %s, "answer": %s, "config": [ %s ]'
                % (qs_id_cnt, im_id_cnt,
                   json.dumps('%s' % qa.q), json.dumps('%s' % qa.a),
                   config_items))
            f_handle.write('}\n')
            qs_id_cnt += 1

        im_id_cnt += 1

    return out_set