Skip to content
Snippets Groups Projects
Commit a39ea433 authored by Kevin Shih's avatar Kevin Shih
Browse files

switching to random generation instead of recursive enumeration

parent 784ac86c
No related branches found
No related tags found
No related merge requests found
......@@ -134,6 +134,68 @@ def gen_questions(block_array):
return qas
def _sample_config(on_train):
    """Draw one random grid configuration and return it as an 18-digit string.

    The string holds nine (shape, color) digit pairs. For each cell a value
    in -1..3 is drawn; values <= 0 leave the cell as "00", otherwise the
    value (1..3) is the shape code and a color code in 0..2 is drawn.
    When ``on_train == 1`` any cell where ``shape == color + 1`` has its
    color shifted by +/-1 (mod 3), so those pairings never occur.
    """
    config = [0] * 18
    for i in range(0, 18, 2):
        shape = random.randint(0, 4) - 1
        if shape <= 0:
            # Cell stays at shape 0 / color 0.
            continue
        # NOTE(review): the color draw is assumed to belong to the non-empty
        # branch (source indentation was lost) -- confirm against the repo.
        color = random.randint(0, 2)
        if on_train == 1 and shape == color + 1:  # enforce held-out
            color = (color + random.choice([1, -1])) % 3
        config[i] = shape
        config[i + 1] = color
    return ''.join(str(c) for c in config)


def _format_qa_record(qs_id, im_id, qa, block_array):
    """Return the body of one JSON record (without the surrounding braces).

    Question and answer text go through ``json.dumps`` so quotes and
    backslashes are escaped; for plain ASCII text the result is the same as
    wrapping the text in double quotes.
    """
    import json  # local import: keeps the block self-contained

    cells = ", ".join("%d, %d" % (blk.shape, blk.color) for blk in block_array)
    return "\"question_id\": %d, \"image_id\" : %d, \"question\": %s, \"answer\": %s, \"config\": [ %s ]" % (
        qs_id, im_id,
        json.dumps("%s" % qa.q), json.dumps("%s" % qa.a),
        cells)


def enumerate_all_sample(target_set_size, on_train, eset = None):
    """Randomly sample distinct grid configurations and emit their examples.

    Configurations are drawn until ``target_set_size`` distinct ones have
    been collected; any configuration whose string is in ``eset`` is
    rejected. For every collected configuration an image is saved to
    ``OUTDIR/<im_id_cnt>.jpg`` and one JSON record per question/answer pair
    is written to the module-level ``f_handle``.

    Args:
        target_set_size: number of distinct configurations to collect. The
            sampling loop only ends once this many have been found.
        on_train: pass 1 to exclude the held-out shape/color pairings
            (``shape == color + 1``); any other value samples freely.
        eset: optional collection of configuration strings to avoid
            (e.g. the set returned by a previous call).

    Returns:
        The set of 18-character configuration strings that were generated.

    Side effects:
        Advances the globals ``im_id_cnt`` and ``qs_id_cnt`` and clears
        ``first`` after the first record is written.
    """
    global im_id_cnt
    global qs_id_cnt
    global first

    out_set = set()
    while len(out_set) < target_set_size:
        config_str = _sample_config(on_train)
        # skip if it's in the eset (pre-existing set)
        if eset is not None and config_str in eset:
            continue
        out_set.add(config_str)

    # generate the examples
    for config in out_set:
        block_array = [Block(int(config[i]), int(config[i + 1]))
                       for i in range(0, 18, 2)]
        qas = gen_questions(block_array)
        im = gen_image(block_array)
        im.save(OUTDIR + '/' + str(im_id_cnt) + '.jpg')
        for qa in qas:
            # Records are comma-separated; only the first has no leading comma.
            if first:
                opener = '{'
                first = False
            else:
                opener = ',{'
            record = _format_qa_record(qs_id_cnt, im_id_cnt, qa, block_array)
            f_handle.write(opener + record + '}\n')
            qs_id_cnt = qs_id_cnt + 1
        im_id_cnt += 1
    return out_set
def enumerate_all(block_array):
global im_id_cnt
if im_id_cnt == num_to_gen:
......@@ -211,7 +273,12 @@ if __name__ == "__main__":
f_handle = f
f.write('[')
enumerate_all([])
#enumerate_all([])
print "Generating training examples..."
config_set = enumerate_all_sample(10000, 1, eset = None)
print "Generating test examples..."
# pass the training configs in when generating test examples to avoid duplicates
enumerate_all_sample(5000, 0, eset = config_set)
f.write(']')
f.close()
# [a b c; d e f; g h i]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment