# vqa_cached_features.py
import numpy as np
import ujson
import os
import re
import pdb
import time
import nltk
import threading

import tftools.data 
import image_io
import constants

import tensorflow as tf

_unknown_token = constants.unknown_token

# Module-level wrapper so that data.get_single can be handed to APIs
# (e.g. multiprocessing) that cannot pickle bound methods.
def unwrap_self_get_single(arg, **kwarg):
    return data.get_single(*arg, **kwarg)

class data():
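    """Data manager for VQA training over cached ResNet region features.

    Loads question annotations, question ids, word/answer vocabularies and
    object/attribute label maps from JSON, then assembles per-sample
    batches of encoded questions, positive/negative answers, and
    noun/adjective auxiliary targets.
    """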
    def __init__(self,
                 feat_dir,
                 anno_json,
                 qids_json,
                 vocab_json,
                 ans_vocab_json,
                 obj_labels_json,
                 atr_labels_json,
                 image_size,
                 num_region_proposals,
                 num_neg_answers,
                 channels=3,
                 mode='mcq',
                 resnet_feat_dim=2048,
                 mean_image_filename=None):
        self.feat_dir = feat_dir
        self.h = image_size[0]
        self.w = image_size[1]
        self.c = channels
        self.mode = mode
        self.resnet_feat_dim = resnet_feat_dim
        self.num_region_proposals = num_region_proposals
        self.num_neg_answers = num_neg_answers
        self.anno = self.read_json_file(anno_json)
        self.vocab = self.read_json_file(vocab_json)
        self.ans_vocab = self.read_json_file(ans_vocab_json)
        self.obj_labels = self.read_json_file(obj_labels_json)
        self.atr_labels = self.read_json_file(atr_labels_json)
        self.qids = self.read_json_file(qids_json)
        self.inv_vocab = self.invert_label_dict(self.vocab)
        self.inv_ans_vocab = self.invert_label_dict(self.ans_vocab)
        self.num_questions = len(self.anno)
        self.create_sample_to_question_dict()
        self.lemmatizer = nltk.stem.WordNetLemmatizer()

    def create_sample_to_question_dict(self):
        # Map contiguous sample indices [0, len(qids)) to question ids.
        self.sample_to_question_dict = dict(enumerate(self.qids))

    def invert_label_dict(self, label_dict):
        return {v: k for k, v in label_dict.items()}

    def read_json_file(self, filename):
        print 'Reading {} ...'.format(filename)
        with open(filename, 'r') as f:
            return ujson.load(f)

    def get_single(self, sample, batch_list, worker_id):
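        """Assemble all fields for one sample into batch_list[worker_id].

        The (output list, worker id) calling convention lets this method
        double as a thread worker target.
        """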
        batch = dict()

        batch['region_feats'] = self.get_region_feats(sample)

        question, nouns, adjectives, question_id, question_unencoded = \
            self.get_question(sample)
        batch['question'] = question
        batch['question_nouns'] = nouns
        batch['question_adjectives'] = adjectives
        batch['question_id'] = question_id
        batch['question_unencoded'] = question_unencoded

        positive_answer, nouns, adjectives, positive_answer_unencoded = \
            self.get_positive_answer(
                sample, self.mode)
        batch['positive_answer'] = positive_answer
        batch['positive_answer_nouns'] = nouns
        batch['positive_answer_adjectives'] = adjectives
        batch['positive_answer_unencoded'] = positive_answer_unencoded
        
        negative_answers, nouns, adjectives, negative_answers_unencoded = \
            self.get_negative_answers(
                sample, self.mode)
        batch['negative_answers'] = negative_answers
        batch['negative_answers_nouns'] = nouns
        batch['negative_answers_adjectives'] = adjectives 
        batch['negative_answers_unencoded'] = negative_answers_unencoded

        batch['positive_nouns'] = batch['question_nouns'] \
                                  + batch['positive_answer_nouns']
        batch['positive_adjectives'] = batch['question_adjectives'] \
                                       + batch['positive_answer_adjectives']

        batch['nouns_identity'] = []
        batch['nouns_identity'].append(
            [0]*len(batch['question_nouns']) + \
            [1]*len(batch['positive_answer_nouns']))
        for i in xrange(self.num_neg_answers):
            batch['nouns_identity'].append(
                [0]*len(batch['question_nouns']) + \
                [1]*len(batch['negative_answers_nouns'][i]))

        batch['adjectives_identity'] = []
        batch['adjectives_identity'].append(
            [0]*len(batch['question_adjectives']) + \
            [1]*len(batch['positive_answer_adjectives']))
        for i in xrange(self.num_neg_answers):
            batch['adjectives_identity'].append(
                [0]*len(batch['question_adjectives']) + \
                [1]*len(batch['negative_answers_adjectives'][i]))
                    
        _, batch['positive_nouns_vec_enc'] = self.noun_to_obj_id(
            batch['positive_nouns'],
            positive_answer_unencoded)
        _, batch['positive_adjectives_vec_enc'] = self.adj_to_atr_id(
            batch['positive_adjectives'],
            positive_answer_unencoded)

        batch['yes_no_num_feat'] = self.is_yes_no_num(
            positive_answer_unencoded,
            negative_answers_unencoded)

        batch_list[worker_id] = batch

    def get_parallel(self, samples):
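        """Build per-sample batches for `samples` and collate them into a
        single dict mapping each field to a list of per-sample values."""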
        batch_list = [None]*len(samples)
        worker_ids = range(len(samples))
        workers = []
        # Threading is currently disabled; each sample is processed
        # serially. To restore the threaded version, spawn a
        # threading.Thread per sample targeting get_single and join them.
        for count, sample in enumerate(samples):
            self.get_single(sample, batch_list, worker_ids[count])

        batch_size = len(samples)
        batch = dict()
        for key in batch_list[0].keys():
            batch[key] = []
        
        for single_batch in batch_list:
            for key, value in single_batch.items():
                batch[key].append(value)

        return batch

    def is_yes_no_num(
            self, 
            positive_answer_unencoded,
            negative_answers_unencoded):
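        """Return a [num_neg_answers+1, 6] indicator matrix.

        Row 0 corresponds to the positive answer and rows 1..num_neg_answers
        to the negatives; columns flag the presence of 'yes', 'no', '0',
        '1', '2', '3' in the answer string.
        """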

        feat = np.zeros([self.num_neg_answers+1,6])

        if 'yes' in positive_answer_unencoded:
            feat[0,0] = 1.0
        elif 'no' in positive_answer_unencoded:
            feat[0,1] = 1.0
        elif '0' in positive_answer_unencoded:
            feat[0,2] = 1.0
        elif '1' in positive_answer_unencoded:
            feat[0,3] = 1.0
        elif '2' in positive_answer_unencoded:
            feat[0,4] = 1.0
        elif '3' in positive_answer_unencoded:
            feat[0,5] = 1.0
        

        for i in xrange(1,self.num_neg_answers+1):
            # Check the i-th negative answer, not the whole list.
            neg_answer = negative_answers_unencoded[i-1]
            if 'yes' in neg_answer:
                feat[i,0] = 1.0
            elif 'no' in neg_answer:
                feat[i,1] = 1.0
            elif '0' in neg_answer:
                feat[i,2] = 1.0
            elif '1' in neg_answer:
                feat[i,3] = 1.0
            elif '2' in neg_answer:
                feat[i,4] = 1.0
            elif '3' in neg_answer:
                feat[i,5] = 1.0

        return feat

    def noun_to_obj_id(self,list_of_noun_ids,pos_ans):
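        """Map encoded noun ids to object-label ids and a k-hot vector.

        Nouns without a corresponding object label get id -1; the vector
        stays all-zero for negation-like positive answers.
        """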
        obj_ids = [None]*len(list_of_noun_ids)
        vec_enc = np.zeros([1,constants.num_object_labels])
        if pos_ans.lower() in set(['no','0','none','nil']):
            return obj_ids, vec_enc

        for i, id in enumerate(list_of_noun_ids):
            if id in self.inv_vocab:
                if self.inv_vocab[id] in self.obj_labels:
                    obj_ids[i] = int(self.obj_labels[self.inv_vocab[id]])
                    vec_enc[0,obj_ids[i]] = 1.0
                else:
                    obj_ids[i] = -1
            else:
                obj_ids[i] = -1
                
        
        return obj_ids, vec_enc

    def adj_to_atr_id(self,list_of_adj_ids,pos_ans):
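        """Map encoded adjective ids to attribute-label ids and a k-hot
        vector, mirroring noun_to_obj_id."""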
        atr_ids = [None]*len(list_of_adj_ids)
        vec_enc = np.zeros([1,constants.num_attribute_labels])
        if pos_ans.lower() in set(['no','0','none','nil']):
            return atr_ids, vec_enc

        for i, id in enumerate(list_of_adj_ids):
            if id in self.inv_vocab:
                if self.inv_vocab[id] in self.atr_labels:
                    atr_ids[i] = int(self.atr_labels[self.inv_vocab[id]])
                    vec_enc[0,atr_ids[i]] = 1.0
                else:
                    atr_ids[i] = -1
            else:
                atr_ids[i] = -1

        return atr_ids, vec_enc

    
    def get_region_feats(self, sample):
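        """Load the cached ResNet feature array for the sample's image.

        The COCO split name (e.g. 'train2014') is recovered from the last
        component of the feature directory path.
        """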
        question_id = self.sample_to_question_dict[sample]
        image_id = self.anno[str(question_id)]['image_id']
        data_split = re.split(
            '_',
            os.path.split(self.feat_dir)[1])[0]

        feat_path = os.path.join(
            self.feat_dir,
            'COCO_' + data_split + '_' + str(image_id).zfill(12) + '.npy')
        return np.load(feat_path)

    def get_single_image(self, sample, region_number, batch_list, worker_id):
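        """Thread worker: read one region crop into batch_list[worker_id],
        substituting a zero image when the read fails."""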
        try:
            batch = dict()
            question_id = self.sample_to_question_dict[sample]
            region_image, read_success = self.get_region_image(
                sample,
                region_number)

            if not read_success:
                region_image = np.zeros(
                    [self.h, self.w, self.c], np.float32)

            batch_list[worker_id] = region_image

        except Exception as e:
            print 'Error in thread {}: {}'.format(
                threading.current_thread().name, str(e))

    def get_region_image(self, sample, region_number):
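        # NOTE: self.image_dir is never set in __init__; it must be
        # assigned on the instance before this method is called. The
        # 'COCO_train2014_' prefix also hard-codes the train split.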
        question_id = self.sample_to_question_dict[sample]
        image_id = self.anno[str(question_id)]['image_id']
        image_subdir = os.path.join(
            self.image_dir,
            'COCO_train2014_' + str(image_id).zfill(12))
        
        filename = os.path.join(image_subdir,
                                str(region_number+1) + '.jpg')
        read_success = True
        try:
            region_image = image_io.imread(filename)
            region_image = region_image.astype(np.float32)
        except Exception:
            # Image could not be read; fall back to an all-zero image.
            read_success = False
            region_image = np.zeros([self.h, self.w, self.c], dtype=np.float32)

        return region_image, read_success
    
    def get_question(self, sample):
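        """Return the encoded parsed question (one encoding per parse bin),
        its encoded nouns and adjectives, the question id, and the raw
        question string."""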
        question_id = self.sample_to_question_dict[sample]
        question_nouns = self.encode_sentence(
            ' '.join(self.anno[question_id]['question_nouns']))
        question_adjectives = self.encode_sentence(
            ' '.join(self.anno[question_id]['question_adjectives']))
        parsed_question = self.anno[question_id]['parsed_question']
        unencoded_question = self.anno[question_id]['question']

        encoded_parsed_question = dict()
        for bin_name, words in parsed_question.items():
            encoded_parsed_question[bin_name] = self.encode_sentence(words)
        return encoded_parsed_question, question_nouns, question_adjectives, \
            question_id, unencoded_question

    def get_positive_answer(self, sample, mode='mcq'):
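        """Return the encoded positive answer plus its nouns/adjectives.

        In 'mcq' mode the dataset's multiple_choice_answer is used;
        otherwise the most frequent annotator answer is chosen.
        """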
        question_id = self.sample_to_question_dict[sample]
        if mode=='mcq':
            positive_answer = self.anno[question_id]['multiple_choice_answer'].lower()
            popular_answer = positive_answer

        else:
            answers = self.anno[question_id]['answers']
            answer_counts = dict()
            for answer in answers:
                answer_lower = answer['answer'].lower()
                if answer_lower not in answer_counts:
                    answer_counts[answer_lower] = 1
                else:
                    answer_counts[answer_lower] += 1

            popular_answer = ''
            current_count = 0
            for answer, count in answer_counts.items():
                if count > current_count:
                    popular_answer = answer
                    current_count = count

        nouns, adjectives = self.get_nouns_adjectives(popular_answer)
        answer = self.encode_sentence(popular_answer)
        return answer, nouns, adjectives, popular_answer

    def get_negative_answers(self, sample, mode='mcq'):
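        """Return encoded negative answers and their nouns/adjectives.

        In 'mcq' mode negatives are the remaining multiple choices;
        otherwise num_neg_answers answers are sampled without replacement
        from the answer vocabulary, excluding this question's annotator
        answers.
        """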
        question_id = self.sample_to_question_dict[sample]
        positive_answers = []
        for answer in self.anno[question_id]['answers']:
            positive_answers.append(answer['answer'].lower())

        if mode=='mcq':
            multiple_choices = self.anno[question_id]['multiple_choices']
            remaining_answers = [
                ans.lower() for ans in multiple_choices
                if ans.lower() not in positive_answers]
            # Keep at most num_neg_answers negatives.
            sampled_negative_answers = remaining_answers[:self.num_neg_answers]
        else:
            remaining_answers = [
                ans.lower() for ans in self.ans_vocab.keys()
                if ans.lower() not in positive_answers]
            # Convert to a list so the padding below can append to it.
            sampled_negative_answers = list(np.random.choice(
                remaining_answers,
                size=self.num_neg_answers,
                replace=False))

        # Pad with the unknown token if too few negatives remain.
        remainder = self.num_neg_answers-len(sampled_negative_answers)
        for i in xrange(remainder):
            sampled_negative_answers.append(constants.unknown_token)

        encoded_answers = []                
        encoded_nouns = []
        encoded_adjectives = []
        for answer in sampled_negative_answers:
            nouns, adjectives = self.get_nouns_adjectives(answer)
            encoded_nouns.append(nouns)
            encoded_adjectives.append(adjectives)
            encoded_answers.append(self.encode_sentence(answer))
            
        return encoded_answers, encoded_nouns, encoded_adjectives, sampled_negative_answers

    def get_nouns_adjectives(self, sentence):
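        """POS-tag the sentence and return encoded, lemmatized nouns
        (NN/NNS/NNP/NNPS) and adjectives (JJ/JJR/JJS)."""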
        words = nltk.tokenize.word_tokenize(sentence)
        nouns = []
        adjectives = []
        for word, pos_tag in nltk.pos_tag(words):
            if pos_tag in ['NN', 'NNS', 'NNP', 'NNPS']:
                nouns.append(self.lemmatizer.lemmatize(word.lower()))
            elif pos_tag in ['JJ', 'JJR', 'JJS']:
                adjectives.append(self.lemmatizer.lemmatize(word.lower()))
        nouns = self.encode_sentence(' '.join(nouns))
        adjectives = self.encode_sentence(' '.join(adjectives))
        return nouns, adjectives

    def encode_sentence(self, sentence):
        # Split into words with only characters and numbers
        words = re.split(r'\W+', sentence.lower())
        
        # Remove ''
        words = [word for word in words if word!='']

        # If no words are left put an unknown_token
        if not words:
            words = [constants.unknown_token]

        encoded_sentence = []
        for word in words:
            if word not in self.vocab:
                word = constants.unknown_token
            encoded_sentence.append(int(self.vocab[word]))
            
        return encoded_sentence


if __name__=='__main__':
    # NOTE: the qids, object-label, and attribute-label constants below
    # are assumed names for the three arguments the constructor requires
    # but the original call omitted; substitute the actual entries from
    # constants.py.
    data_mgr = data(
        constants.vqa_train_resnet_feat_dir,
        constants.vqa_train_anno,
        constants.vqa_train_qids,
        constants.vocab_json,
        constants.vqa_answer_vocab_json,
        constants.object_labels_json,
        constants.attribute_labels_json,
        constants.image_size,
        constants.num_region_proposals,
        constants.num_negative_answers)

    batch = data_mgr.get_parallel(xrange(10))
    pdb.set_trace()
    
    # Usage sketch (commented out): stream batches using the index and
    # async batch generators from tftools.data.

    # batch_size = 200
    # num_samples = 200
    # num_epochs = 1
    # offset = 0
    # queue_size = 100

    # index_generator = tftools.data.random(
    #     batch_size, 
    #     num_samples, 
    #     num_epochs, 
    #     offset)
    
    # batch_generator = tftools.data.async_batch_generator(
    #     data_mgr, 
    #     index_generator, 
    #     queue_size)

    # count = 0 
    # for batch in batch_generator:
    #     print 'Batch Number: {}'.format(count)
    #     count += 1