Skip to content
Snippets Groups Projects
Commit dbcfe37f authored by Kevin Shih's avatar Kevin Shih
Browse files

adding shape dataset generation

parents
No related branches found
No related tags found
No related merge requests found
#! /bin/python2
import collections
import sys
import os
import json
import random
from PIL import Image, ImageDraw
from shape_globals import Globals
# Name tables shared with the other scripts; a Block stores indices into them.
SHAPES = Globals.SHAPES
COLORS = Globals.COLORS
# shape index: 0 = blank, 1 = square, 2 = triangle, 3 = circle
# color index: 0 = red, 1 = green, 2 = blue (ignored when the shape is blank;
# blank cells are always created as Block(0, 0))
Block = collections.namedtuple('Block', ['shape', 'color'])
# One question/answer pair, both stored as plain strings.
QA = collections.namedtuple('QA', ['q', 'a'])
# Inset in pixels between a 100x100 grid cell and the shape drawn inside it.
eps = 10;
# Directory the images are written to; overwritten from argv in the main block.
OUTDIR = ''
# Id of the next image / next question to be written (both start at 1).
im_id_cnt = 1
qs_id_cnt = 1
# enumerate_all() stops as soon as im_id_cnt reaches this value.
num_to_gen = 15000
# While im_id_cnt is at or below this id, enumerate_all() skips some
# shape/color combinations.
train_im_id_max = 9999
# True until the first JSON record has been written (controls the ',' separator).
first = True
# Open annotation file; replaced by a real file object in the main block.
f_handle = ''
def gen_image(block_array):
    """Render a grid of blocks as a 300x300 RGBA image.

    Cells are laid out row-major, three per row, each 100x100 pixels.
    Every non-blank block is drawn inset by ``eps`` pixels from its cell,
    filled with ``COLORS[blk.color]`` and outlined in black.

    Args:
        block_array: sequence of Block(shape, color); shape 0 is blank,
            1 a square, 2 a triangle, anything else an ellipse.

    Returns:
        The PIL image that was drawn.
    """
    image = Image.new('RGBA', (300, 300))
    draw = ImageDraw.Draw(image)
    for cnt, blk in enumerate(block_array):
        # Floor division keeps the coordinates integral on Python 2 and 3.
        row, col = divmod(cnt, 3)
        # NOTE: x carries an extra +1 offset that y does not; kept as-is so
        # the rendered pixels do not change.
        x1 = col * 100 + 1 + eps
        x2 = x1 + 100 - 2 * eps
        y1 = row * 100 + eps
        y2 = y1 + 100 - 2 * eps
        if blk.shape <= 0:
            continue  # blank cell: nothing to draw
        fill = COLORS[blk.color]
        if blk.shape == 1:  # square
            draw.rectangle((x1, y1, x2, y2), fill=fill, outline='black')
        elif blk.shape == 2:  # triangle pointing down
            draw.polygon((x1, y1, x2, y1, (x1 + x2) // 2, y2),
                         fill=fill, outline='black')
        else:  # circle
            draw.ellipse((x1, y1, x2, y2), fill=fill, outline='black')
    return image
def gen_questions(block_array):
    """Build the question/answer pairs for one 3x3 block configuration.

    Four families of questions are produced, in this order:
      * existence ("Is there a <color> <shape>?"), each of the nine
        color/shape pairs being included with 50% probability;
      * counting ("How many <color> <shape>s are there?") for pairs that
        occur more than once, plus the total number of shapes;
      * colour ("What color is the <shape>?") for shapes occurring once;
      * relative position ("Is there a X below a Y?") for a unique Y that
        has a non-blank block directly beneath it, together with a
        perturbed variant (colour of X shifted) whose answer is 'no'.

    Args:
        block_array: nine Block(shape, color) tuples in row-major order.

    Returns:
        A list of QA(q, a) tuples.
    """
    # shapecoltable[shape - 1][color] -> number of such blocks in the grid
    shapecoltable = [[0] * 3 for _ in range(3)]
    num_shapes = 0
    for blk in block_array:
        if blk.shape > 0:
            shapecoltable[blk.shape - 1][blk.color] += 1
            num_shapes += 1

    qas = []

    # is there a ...
    for r in range(3):
        for c in range(3):
            # One randint call per pair, in r/c order, so that a seeded run
            # consumes the random stream exactly as before.
            if random.randint(1, 2) == 1:
                ques = 'Is there a %s %s?' % (COLORS[c], SHAPES[r + 1])
                ans = 'yes' if shapecoltable[r][c] > 0 else 'no'
                qas.append(QA(ques, ans))

    # how many
    for r in range(3):
        for c in range(3):
            if shapecoltable[r][c] > 1:
                ques = 'How many %s %ss are there?' % (COLORS[c], SHAPES[r + 1])
                qas.append(QA(ques, '%d' % (shapecoltable[r][c])))
    qas.append(QA('How many shapes are there?', '%d' % (num_shapes)))

    # what color
    for r in range(3):
        if sum(shapecoltable[r]) == 1:
            # Exactly one entry of the row is 1, so index() finds the colour.
            c = shapecoltable[r].index(1)
            ques = 'What color is the %s?' % (SHAPES[r + 1])
            qas.append(QA(ques, '%s' % (COLORS[c])))

    # relative position (above/below)
    for cnt, blk in enumerate(block_array):
        # Only the first two rows have a cell underneath them.
        if blk.shape <= 0 or cnt // 3 >= 2:
            continue
        # The upper block must be the only one of its shape/colour, otherwise
        # the question would be ambiguous.
        if shapecoltable[blk.shape - 1][blk.color] != 1:
            continue
        below = block_array[cnt + 3]
        if below.shape <= 0:
            continue
        b_s = SHAPES[below.shape]
        top_c = COLORS[blk.color]
        top_s = SHAPES[blk.shape]
        ques = 'Is there a %s %s below a %s %s?' % (
            COLORS[below.color], b_s, top_c, top_s)
        qas.append(QA(ques, 'yes'))
        # now perturb the question: same shapes, wrong colour underneath
        b_c_p = COLORS[(below.color + 1) % 3]
        ques = 'Is there a %s %s below a %s %s?' % (b_c_p, b_s, top_c, top_s)
        qas.append(QA(ques, 'no'))
    return qas
def enumerate_all(block_array):
    """Recursively enumerate block grids, writing an image and its QAs for each.

    ``block_array`` is extended one block at a time, trying shapes and
    colours in a freshly shuffled order at every level. Once it holds nine
    blocks, the image is saved as ``<OUTDIR>/<im_id_cnt>.jpg`` and one JSON
    object per question is written to ``f_handle``.

    Recursion stops once ``im_id_cnt`` reaches ``num_to_gen``, so image ids
    1 .. num_to_gen - 1 are produced.

    While ``im_id_cnt <= train_im_id_max`` the combinations red square,
    green triangle and blue circle are skipped.

    Args:
        block_array: list of Block tuples chosen so far (at most nine).
    """
    global im_id_cnt, qs_id_cnt, first
    if im_id_cnt == num_to_gen:
        return

    # base case: a full 3x3 grid
    if len(block_array) == 9:
        qas = gen_questions(block_array)
        im = gen_image(block_array)
        # JPEG cannot store an alpha channel and current Pillow refuses to
        # write RGBA as JPEG; dropping alpha explicitly gives the same pixels
        # older PIL versions wrote.
        im.convert('RGB').save(OUTDIR + '/' + str(im_id_cnt) + '.jpg')
        if im_id_cnt % 100 == 0:
            print(im_id_cnt)
        # The flattened "shape, color, shape, color, ..." list is the same
        # for every question about this image.
        config_str = ', '.join(
            '%d, %d' % (blk.shape, blk.color) for blk in block_array) + ' ]'
        # write the questions to 'f_handle'
        for qa in qas:
            # Every record but the very first is preceded by a comma.
            f_handle.write('{' if first else ',{')
            first = False
            qstr = "\"question_id\": %d, \"image_id\" : %d, \"question\": \"%s\", \"answer\": \"%s\", \"config\": [ " % (qs_id_cnt, im_id_cnt, qa.q, qa.a)
            f_handle.write(qstr)
            f_handle.write(config_str)
            f_handle.write('}\n')
            qs_id_cnt = qs_id_cnt + 1
        im_id_cnt = im_id_cnt + 1
        return

    # list() is needed on Python 3, where range() cannot be shuffled in place.
    order = list(range(4))
    corder = list(range(3))
    random.shuffle(corder)
    random.shuffle(order)
    # Decided once per call, before recursing, exactly as the original
    # if/else around the two loops did.
    skip_held_out = im_id_cnt <= train_im_id_max
    for s in order:
        if s <= 0:
            enumerate_all(block_array + [Block(0, 0)])
            continue
        for c in corder:
            # (1, 0), (2, 1), (3, 2): red square, green triangle, blue circle
            if skip_held_out and c == s - 1:
                continue
            enumerate_all(block_array + [Block(s, c)])
if __name__ == "__main__":
    # usage: <script> IMAGE_OUTPUT_DIR ANNOTATION_JSON
    if len(sys.argv) < 3:
        sys.exit('usage: %s IMAGE_OUTPUT_DIR ANNOTATION_JSON' % sys.argv[0])
    outdir = sys.argv[1]
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    # Fixed seed so repeated runs generate the same dataset.
    random.seed(42)
    OUTDIR = outdir
    # 'with' guarantees the annotation file is flushed and closed even if
    # generation fails part-way through.
    with open(sys.argv[2], 'w') as f:
        f_handle = f
        # enumerate_all() writes the comma-separated records of a JSON array.
        f.write('[')
        enumerate_all([])
        f.write(']')
# [a b c; d e f; g h i]
# a = Block(3, 0);
# b = Block(2, 2);
# c = Block(2, 0);
# d = Block(0, 0);
# e = Block(2, 1);
# f = Block(0, 0);
# g = Block(1, 2);
# h = Block(2, 0);
# i = Block(2, 2);
# gen_questions([a, b, c, d, e, f, g, h, i])
# gen_image([a, b, c, d, e, f, g, h, i])
#! /bin/sh
# Driver for the dataset pipeline. The image-generation and region steps
# below are commented out, so running this script only performs the final
# train/test split of an existing anno.json.
#mkdir -p images
#python2 gen_dataset.py images anno.json
#python2 gen_regions.py anno.json regions_anno.json
python2 split_dataset.py anno.json train_anno.json test_anno.json
#! /bin/python2
import json
import sys
from shape_globals import Globals
SHAPES = Globals.SHAPES
COLORS = Globals.COLORS
# given string phrase and image configuration, return bounding box in 300x300 image
# box is in [x1 y1 x2 y2]
def _cell_bbox(cell, rows=1):
    """Return [x1, y1, x2, y2] of grid cell `cell`, spanning `rows` cells down."""
    # Floor division keeps the coordinates integral on Python 2 and 3.
    x1 = (cell % 3) * 100 + 1
    y1 = (cell // 3) * 100 + 1
    return [x1, y1, x1 + 99, y1 + rows * 100 - 1]


def get_region(config, phrase):
    """Locate the image region a phrase refers to.

    Args:
        config: flat list ``[shape0, color0, shape1, color1, ...]`` of
            indices into SHAPES / COLORS, row-major, three cells per row.
        phrase: one of
            * ``"<shape>"``                               (1 token),
            * ``"<color> <shape>"``, optional plural 's'  (2 tokens),
            * ``"<color> <shape> below a <color> <shape>"`` (6 tokens).

    Returns:
        ``[x1, y1, x2, y2]`` in image pixels. For the one- and two-token
        forms this is the first matching cell; for the six-token form it is
        the two-cell-high box covering the upper cell and the one below it
        (the last match if there are several). ``[-1, -1, -1, -1]`` when
        nothing matches.
    """
    region_bbox = [-1, -1, -1, -1]
    phrase_toks = phrase.split(' ')
    num_toks = len(phrase_toks)

    if num_toks == 1:
        # just a shape
        for i in range(0, len(config), 2):
            if SHAPES[config[i]] == phrase_toks[0]:
                return _cell_bbox(i // 2)
    elif num_toks == 2:
        # color + shape
        color, shape = phrase_toks
        if shape.endswith('s'):  # drop the plural; safe on an empty token
            shape = shape[:-1]
        for i in range(0, len(config), 2):
            if SHAPES[config[i]] == shape and COLORS[config[i + 1]] == color:
                return _cell_bbox(i // 2)
    elif num_toks == 6:
        low_color, low_shape = phrase_toks[0], phrase_toks[1]
        top_color, top_shape = phrase_toks[4], phrase_toks[5]
        # Only cells that have a full row beneath them; the cell below sits
        # 3 cells (6 config entries) further on. For a 3x3 grid this is
        # range(0, 12, 2).
        for i in range(0, len(config) - 6, 2):
            if (SHAPES[config[i + 6]] == low_shape
                    and COLORS[config[i + 7]] == low_color
                    and SHAPES[config[i]] == top_shape
                    and COLORS[config[i + 1]] == top_color):
                region_bbox = _cell_bbox(i // 2, rows=2)
    return region_bbox
if __name__ == "__main__":
    # usage: <script> ANNO_JSON REGIONS_OUT_JSON
    # Reads the question annotations and writes, per image, a mapping from
    # every phrase that can be grounded to its bounding box.
    with open(sys.argv[1], 'r') as f:  # path to 'anno.json'
        annos = json.load(f)

    # Annotations are expected to be grouped by image, starting at id 1.
    curr_im_id = 1
    anno_out = []
    im_regions = {}
    for anno in annos:
        if anno['image_id'] != curr_im_id:  # next image, dump regions
            anno_out.append({'image_id': curr_im_id, 'regions': im_regions})
            curr_im_id = anno['image_id']
            im_regions = {}  # reset the set

        qs = anno['question'][:-1]  # ignore the '?'
        toks = qs.split(" ")
        np_1 = None  # first noun phrase
        np_2 = None  # second noun phrase ("... below a <np_2>")
        rel = None   # whole relational phrase
        if "Is there a " in qs:
            # existence
            if anno['answer'] == "no":
                # existence is false, nothing to match
                continue
            np_1 = ' '.join(toks[3:5])
            if "below" in qs:
                np_2 = ' '.join(toks[7:9])
                rel = ' '.join(toks[3:9])
        elif "What color is the" in qs:
            np_1 = toks[-1]
        else:
            continue

        config = anno['config']
        if np_1 not in im_regions:
            im_regions[np_1] = get_region(config, np_1)
        if np_2 is not None and np_2 not in im_regions:
            im_regions[np_2] = get_region(config, np_2)
        if rel is not None:
            im_regions[rel] = get_region(config, rel)

    # Flush the regions of the last image.
    anno_out.append({'image_id': curr_im_id, 'regions': im_regions})

    dump_str = json.dumps(anno_out, sort_keys=True, indent=4, separators=(',', ': '))
    with open(sys.argv[2], 'w') as f_out:  # output json file
        f_out.write(dump_str)
#! /bin/python2
class Globals(object):
    """Name tables shared by the dataset scripts.

    Blocks store integer indices into these lists.
    """
    # Index 0 is the empty cell.
    SHAPES = ['blank', 'square', 'triangle', 'circle']
    COLORS = ['red', 'green', 'blue']
#! /bin/python2
import json
import sys
if __name__ == "__main__":
    # usage: <script> FULL_ANNO_JSON TRAIN_OUT_JSON TEST_OUT_JSON
    full_anno_f = sys.argv[1]
    train_anno_f = sys.argv[2]
    test_anno_f = sys.argv[3]

    # Questions about images with an id below this go to the training split.
    first_test_im_id = 10000

    with open(full_anno_f, 'r') as anno_file:
        data = json.load(anno_file)

    train_data = [q for q in data if q['image_id'] < first_test_im_id]
    test_data = [q for q in data if q['image_id'] >= first_test_im_id]

    # 'with' makes sure both outputs are flushed and closed.
    with open(train_anno_f, 'w') as train_f:
        json.dump(train_data, train_f)
    with open(test_anno_f, 'w') as test_f:
        json.dump(test_data, test_f)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment