# split_genome.py
# Assigns Visual Genome regions to train/val/test splits aligned with VQA,
# matching images across datasets via md5 hashes of their pixel data.
import ujson
import hashlib
import os
import glob
import numpy as np
import constants
import image_io
import pdb

def vqa_filename_to_image_id(
        filename):
    """Extract the numeric COCO image id from a VQA image path.

    E.g. '/dir/COCO_train2014_000000123456.jpg' -> '123456' (leading
    zeros dropped via the int round-trip).

    Note: the previous implementation used str.strip('.jpg'), which strips
    a *character set* from both ends rather than removing the extension;
    os.path.splitext removes exactly the suffix.
    """
    _, image_name = os.path.split(filename)
    stem = os.path.splitext(image_name)[0]
    return str(int(stem.split('_')[-1]))

def genome_filename_to_image_id(
        filename):
    """Extract the Visual Genome image id from an image path.

    E.g. '/dir/2407.jpg' -> '2407'.

    Note: the previous implementation used str.strip('.jpg'), which strips
    the character set {'.', 'j', 'p', 'g'} from both ends and would corrupt
    any id beginning or ending with one of those characters;
    os.path.splitext removes exactly the extension.
    """
    _, image_name = os.path.split(filename)
    return os.path.splitext(image_name)[0]

def image_id_to_vqa_filename(
        image_dir,
        image_id,
        mode):
    """Build the COCO/VQA image path for *image_id* under *image_dir*.

    COCO images are named 'COCO_<mode>_<12-digit zero-padded id>.jpg',
    where *mode* is e.g. 'train2014', 'val2014' or 'test2015'.
    """
    basename = 'COCO_{}_{}.jpg'.format(mode, image_id.zfill(12))
    return os.path.join(image_dir, basename)

def image_id_to_genome_filename(
        image_dir,
        image_id):
    """Return the path of the Visual Genome image named '<image_id>.jpg'."""
    return os.path.join(image_dir, image_id + '.jpg')

def generate_md5hash(
        image_dir,
        hash_json,
        filename_to_image_id):
    """Hash every '*.jpg' in *image_dir* and write {md5: image_id} to *hash_json*.

    The md5 is computed over the decoded pixel buffer (image_io.imread),
    not the raw file bytes, so the same picture stored in different
    directories/filenames maps to the same key. Unreadable images are
    skipped with a warning rather than aborting the whole run.

    Args:
        image_dir: directory containing the .jpg files.
        hash_json: output path for the json mapping.
        filename_to_image_id: callable mapping an image filename to its id.
    """
    filenames = glob.glob(os.path.join(image_dir, '*.jpg'))
    md5_hashes = dict()
    for filename in filenames:
        try:
            im = image_io.imread(filename)
        except Exception as err:
            # Corrupt/unreadable image: report and keep going (previously a
            # bare except that also hid KeyboardInterrupt etc.).
            print('Can not read {}: {}'.format(filename, err))
            continue
        md5_hashes[hashlib.md5(im).hexdigest()] = \
            filename_to_image_id(filename)

    with open(hash_json, 'w') as out_file:
        ujson.dump(md5_hashes, out_file)


if __name__=='__main__':
    # --- Input/output paths ---------------------------------------------
    genome_image_dir = os.path.join(
        constants.data_absolute_path,
        'images')

    vqa_train2014_image_dir = os.path.join(
        constants.vqa_basedir,
        'train2014')

    vqa_val2014_image_dir = os.path.join(
        constants.vqa_basedir,
        'val2014')

    vqa_test2015_image_dir = os.path.join(
        constants.vqa_basedir,
        'test2015')

    vqa_train2014_held_out_qids_json = os.path.join(
        constants.vqa_basedir,
        'train_held_out_qids.json')

    vqa_train2014_anno_json = os.path.join(
        constants.vqa_basedir,
        'mscoco_train2014_annotations_with_parsed_questions.json')

    genome_hash_json = os.path.join(
        constants.data_absolute_path,
        'genome_hash.json')

    vqa_train2014_hash_json = os.path.join(
        constants.vqa_basedir,
        'train2014_hash.json')

    vqa_val2014_hash_json = os.path.join(
        constants.vqa_basedir,
        'val2014_hash.json')

    vqa_test2015_hash_json = os.path.join(
        constants.vqa_basedir,
        'test2015_hash.json')

    genome_not_in_vqa_train2014_subset_region_ids_json = os.path.join(
        constants.data_absolute_path,
        'restructured/genome_not_in_vqa_train2014_subset_region_ids.json')

    genome_train_subset_region_ids_json = os.path.join(
        constants.data_absolute_path,
        'restructured/train_subset_region_ids.json')

    genome_train_held_out_region_ids_json = os.path.join(
        constants.data_absolute_path,
        'restructured/train_held_out_region_ids.json')

    genome_test_region_ids_json = os.path.join(
        constants.data_absolute_path,
        'restructured/test_region_ids.json')

    # NOTE: the *_hash.json files loaded below must already exist; they are
    # produced by one-off calls to generate_md5hash() on each image
    # directory (genome, train2014, val2014, test2015).

    print('Reading {} ...'.format(vqa_train2014_anno_json))
    with open(vqa_train2014_anno_json, 'r') as json_file:
        vqa_train2014_anno = ujson.load(json_file)

    print('Reading {} ...'.format(vqa_train2014_held_out_qids_json))
    with open(vqa_train2014_held_out_qids_json, 'r') as json_file:
        vqa_train2014_held_out_qids = ujson.load(json_file)

    # Image ids (as strings) of the held-out VQA train2014 questions.
    vqa_train2014_held_out_image_ids = set(
        str(vqa_train2014_anno[qid]['image_id'])
        for qid in vqa_train2014_held_out_qids)

    # The first entry must always be the genome hash; the position of each
    # remaining entry determines the split index assigned below.
    hashes_to_load = [
        genome_hash_json,
        vqa_train2014_hash_json,
        vqa_val2014_hash_json,
        vqa_test2015_hash_json]

    hashes = []
    for hash_to_load in hashes_to_load:
        with open(hash_to_load, 'r') as json_file:
            hashes.append(ujson.load(json_file))

    # Assign each genome image (matched by md5 of its pixels) to a split:
    # 0 = unknown, 1 = train2014_subset, 2 = val2014, 3 = test2015,
    # 4 = train2014_held_out.
    image_assignment = dict()
    for hash_key, image_id in hashes[0].items():
        image_assignment[image_id] = 0
        for i, vqa_hash in enumerate(hashes[1:]):
            if hash_key in vqa_hash:
                if i == 0 and vqa_hash[hash_key] in vqa_train2014_held_out_image_ids:
                    image_assignment[image_id] = 4
                else:
                    image_assignment[image_id] = i + 1

    counts = {split_idx: 0 for split_idx in range(5)}
    for value in image_assignment.values():
        counts[value] += 1
    print(counts)

    print('Reading {} ...'.format(constants.regions_json))
    with open(constants.regions_json, 'r') as json_file:
        regions_data = ujson.load(json_file)

    splits = [
        'unknown',
        'train2014_subset',
        'val2014',
        'test2015',
        'train2014_held_out']

    # Bucket region ids by the split assignment of their parent image.
    sets = [set() for _ in range(len(splits))]
    for region_id, region_data in regions_data.items():
        region_image_id = str(region_data['image_id'])
        if region_image_id in image_assignment:
            sets[image_assignment[region_image_id]].add(region_id)

    # 'Safe' regions come from images not present in VQA val/test/held-out.
    genome_safe_region_ids = list(sets[0]) + list(sets[1])
    genome_test_region_ids = list(sets[2]) + list(sets[3])

    with open(genome_not_in_vqa_train2014_subset_region_ids_json, 'w') as json_file:
        ujson.dump(genome_safe_region_ids, json_file)

    with open(genome_test_region_ids_json, 'w') as json_file:
        ujson.dump(genome_test_region_ids, json_file)

    # Randomly carve the safe regions into a 90/10 train/held-out split.
    split_frac = 0.9
    num_valid_train_regions = len(genome_safe_region_ids)
    num_genome_train_subset_regions = int(num_valid_train_regions*split_frac)
    train_subset_ids = set(
        np.random.choice(
            num_valid_train_regions,
            num_genome_train_subset_regions,
            replace=False))

    genome_train_subset_region_ids = set()
    genome_train_held_out_region_ids = set()
    for i in range(num_valid_train_regions):
        if i in train_subset_ids:
            genome_train_subset_region_ids.add(genome_safe_region_ids[i])
        else:
            genome_train_held_out_region_ids.add(genome_safe_region_ids[i])

    print('{} {}'.format(
        len(genome_train_subset_region_ids),
        len(genome_train_held_out_region_ids)))

    # Sets are not json-serializable; convert to lists before dumping.
    with open(genome_train_subset_region_ids_json, 'w') as json_file:
        ujson.dump(list(genome_train_subset_region_ids), json_file)

    with open(genome_train_held_out_region_ids_json, 'w') as json_file:
        ujson.dump(list(genome_train_held_out_region_ids), json_file)