# from word2vec.word_vector_management import word_vector_manager
# import object_attribute_classifier.inference as feature_graph 
# import region_relevance_network.inference as relevance_graph
# import answer_classifier.inference as answer_graph
from tftools import var_collect, placeholder_management
import tftools.data
import losses
import constants
import tftools.var_collect as var_collect
import data.vqa_cached_features as vqa_data
import answer_classifier_cached_features.train as train

import numpy as np
import pdb
import ujson
import os
import tensorflow as tf


def create_initializer(graph, sess, model):
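    """Build an initializer for evaluating a saved model.

    Variables in graph.vars_to_save are restored from the `model` checkpoint
    during construction; calling initialize() runs the init op for every
    other variable in the graph.
    """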
    class initializer():
        def __init__(self):
            with graph.tf_graph.as_default():
                model_vars = graph.vars_to_save
                model_restorer = tf.train.Saver(model_vars)
                model_restorer.restore(sess, model)    
                not_to_init = model_vars
                all_vars = tf.all_variables()
                other_vars = [var for var in all_vars
                              if var not in not_to_init]
                var_collect.print_var_list(
                    other_vars,
                    'vars_to_init')
                self.init = tf.initialize_variables(other_vars)

        def initialize(self):
            sess.run(self.init)
    
    return initializer()

def create_batch_generator():
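    """Build an asynchronous batch generator over the held-out training questions.

    Indices are drawn sequentially over the cached VQA features for the
    question ids listed in constants.vqa_train_held_out_qids.
    """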
    data_mgr = vqa_data.data(
        constants.vqa_train_resnet_feat_dir,
        constants.vqa_train_anno,
        constants.vqa_train_held_out_qids,
        constants.vocab_json,
        constants.vqa_answer_vocab_json,
        constants.object_labels_json,
        constants.attribute_labels_json,
        constants.image_size,
        constants.num_region_proposals,
        constants.num_negative_answers,
        resnet_feat_dim=constants.resnet_feat_dim)

    num_train_held_out_questions = len(data_mgr.qids)    

    index_generator = tftools.data.sequential(
        constants.answer_batch_size,
        num_train_held_out_questions,
        1)

    batch_generator = tftools.data.async_batch_generator(
        data_mgr, 
        index_generator, 
        constants.answer_queue_size)
    
    return batch_generator


def create_feed_dict_creator(plh, num_neg_answers):
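    """Return a function that converts a raw VQA batch into a feed dict.

    `plh` is the graph's placeholder manager; `num_neg_answers` is the number
    of negative answers provided per question.
    """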
    def feed_dict_creator(batch):
        vqa_batch = batch
        batch_size = len(vqa_batch['question'])
        # Create vqa inputs
        inputs = {
            'region_feats': np.concatenate(vqa_batch['region_feats'], axis=0),
            'positive_answer': vqa_batch['positive_answer'],
        }
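        # Each question is encoded as 4 bins (bin_0 ... bin_3); collect the
        # j-th example's encoding for every bin.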
        for i in xrange(4):
            bin_name = 'bin_' + str(i)
            inputs[bin_name] = [
                vqa_batch['question'][j][bin_name] for j in xrange(batch_size)]
        
        for i in xrange(num_neg_answers):
            answer_name = 'negative_answer_' + str(i)
            inputs[answer_name] = [
                vqa_batch['negative_answers'][j][i] for j in xrange(batch_size)]

        inputs['positive_nouns'] = [
            a + b for a, b in zip(
                vqa_batch['question_nouns'],
                vqa_batch['positive_answer_nouns'])]

        inputs['positive_adjectives'] = [
            a + b for a, b in zip(
                vqa_batch['question_adjectives'],
                vqa_batch['positive_answer_adjectives'])]

        for i in xrange(num_neg_answers):
            name = 'negative_nouns_' + str(i)
            list_ith_negative_answer_nouns = [
                vqa_batch['negative_answers_nouns'][j][i]
                for j in xrange(batch_size)]
            inputs[name] = [
                a + b  for a, b in zip(
                    vqa_batch['question_nouns'],
                    list_ith_negative_answer_nouns)]
            
            name = 'negative_adjectives_' + str(i)
            list_ith_negative_answer_adjectives = [
                vqa_batch['negative_answers_adjectives'][j][i]
                for j in xrange(batch_size)]
            inputs[name] = [
                a + b for a, b in zip(
                    vqa_batch['question_adjectives'],
                    list_ith_negative_answer_adjectives)]
            
        inputs['keep_prob'] = 1.0

        return plh.get_feed_dict(inputs)

    return feed_dict_creator


class eval_mgr():
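    """Accumulates accuracy over batches and collects predicted answers.

    Predictions are stored as {'question_id': ..., 'answer': ...} entries and
    periodically written to `results_json`.
    """
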
    def __init__(self, results_json):
        self.results_json = results_json
        self.correct = 0
        self.total = 0
        self.results = []
        self.seen_qids = set()

    def eval(self, iter, eval_vars_dict, batch):
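        """Record predictions and running accuracy for one evaluated batch.

        Every 10 iterations the running accuracy is printed and the results
        file is rewritten.
        """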
        batch_size = len(batch['question_unencoded'])
        
        for j in xrange(batch_size):
            question_id = batch['question_id'][j]
            pred_answer, pred_score = self.get_pred_answer(
                [batch['positive_answer_unencoded'][j]] + \
                batch['negative_answers_unencoded'][j],
                eval_vars_dict['answer_score_' + str(j)][0,:].tolist()
            )

            result_entry = {
                'question_id': int(question_id),
                'answer': pred_answer
            }

            if question_id not in self.seen_qids:
                self.seen_qids.add(question_id)
                self.results.append(result_entry)
            else:
                print 'Already evaluated on this sample'
            

        self.total += batch_size
        
        self.correct += eval_vars_dict['accuracy'][0]*batch_size

        if iter%10==0:
            self.print_accuracy()
            self.write_data()

    def get_pred_answer(self, answers, scores):
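        """Return the highest-scoring answer and its score."""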
        pred_answer = ''
        pred_score = -1e5
        for answer, score in zip(answers, scores):
            if score > pred_score:
                pred_score = score
                pred_answer = answer

        return pred_answer, pred_score

    def is_correct(self, answer_scores):
        max_id = np.argmax(answer_scores, 1)
        return max_id[0] == 0

    def print_accuracy(self):
        print 'Total: {}  Correct: {}  Accuracy: {}'.format(
            self.total,
            self.correct,
            self.get_accuracy())

    def get_accuracy(self):
        return self.correct/float(self.total)

    def write_data(self):
        with open(self.results_json, 'w') as file:
            ujson.dump(self.results, file, indent=4, sort_keys=True)
        
        
def eval(
        batch_generator, 
        sess, 
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        evaluator):
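    """Run the evaluation loop.

    Restores/initializes the graph variables, then evaluates the requested
    tensors on every batch and passes the results to `evaluator`.
    """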

    vars_to_eval_names = []
    vars_to_eval = []
    for var_name, var in vars_to_eval_dict.items():
        vars_to_eval_names += [var_name]
        vars_to_eval += [var]

    with sess.as_default():
        initializer.initialize()

        iter = 0
        for batch in batch_generator:
            if iter%10 == 0:
                print '---'
                print 'Iter: {}'.format(iter)
            feed_dict = feed_dict_creator(batch)
            eval_vars = sess.run(
                vars_to_eval,
                feed_dict = feed_dict)
            eval_vars_dict = {
                var_name: eval_var for var_name, eval_var in
                zip(vars_to_eval_names, eval_vars)}
            evaluator.eval(iter, eval_vars_dict, batch)
            iter+=1
        
        print '---'
        print 'Iter: {}'.format(iter)
        evaluator.print_accuracy()
        evaluator.write_data()


def eval_model(model_to_eval, results_json):
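    """Evaluate a single checkpoint on the held-out training questions.

    Builds the graph, restores `model_to_eval`, runs the evaluation loop,
    writes per-question predictions to `results_json`, and returns the
    accuracy reported by the evaluation manager.
    """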
    print 'Creating batch generator...'
    batch_generator = create_batch_generator()

    print 'Creating computation graph...'
    graph = train.graph_creator(
        constants.tb_log_dir,
        constants.answer_batch_size,
        constants.image_size,
        constants.num_negative_answers,
        constants.answer_regularization_coeff,
        constants.answer_batch_size*constants.num_region_proposals,
        0,
        0,
        0,
        constants.answer_obj_atr_loss_wt,
        constants.answer_ans_loss_wt,
        constants.answer_mil_loss_wt,
        resnet_feat_dim=constants.resnet_feat_dim,
        training=False)

    print 'Starting a session...'
    sess = tf.Session(graph=graph.tf_graph)

    print 'Creating initializer...'
    initializer = create_initializer(
        graph, 
        sess, 
        model_to_eval)

    print 'Creating feed dict creator...'
    feed_dict_creator = create_feed_dict_creator(
        graph.plh,
        constants.num_negative_answers)

    print 'Creating dict of vars to be evaluated...'
    vars_to_eval_dict = {
        'accuracy': graph.answer_accuracy,
    }
    for j in xrange(constants.answer_batch_size):
        vars_to_eval_dict['answer_score_'+str(j)] = \
            graph.answer_inference.answer_score[j]

    print 'Creating evaluation manager...'
    evaluator = eval_mgr(results_json)

    print 'Starting evaluation...'
    eval(
        batch_generator, 
        sess, 
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        evaluator)

    return evaluator.get_accuracy()
def model_path_generator(models_dir, start_model, step_size):
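    """Yield (checkpoint_path, model_number) pairs, stepping by `step_size`.

    Iteration stops at the first 'model-<n>' file that does not exist in
    `models_dir`.
    """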
    model_number = start_model
    filename = os.path.join(models_dir,'model-' + str(model_number))
    while os.path.exists(filename):
        yield filename, model_number
        model_number += step_size
        filename = os.path.join(models_dir,'model-' + str(model_number))
    
    
if __name__=='__main__':
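    # Evaluate every saved checkpoint on the held-out training subset, log
    # each model's accuracy to a text file, and keep track of the best one.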
    np.random.seed(seed=0)
    model_paths = model_path_generator(
        constants.models_dir,
        constants.start_model,
        constants.step_size)

    model_accuracies_txt = open(constants.model_accuracies_txt, 'w')
    best_model = (None, 0.0)
    for model_path, model_number in model_paths:
        results_json = os.path.join(
            constants.answer_output_dir,
            'eval_train_subset' + '_results_' + str(model_number) + '.json')
        accuracy = eval_model(model_path, results_json)
        line = model_path + '\t' + str(accuracy)
        print line
        model_accuracies_txt.write(line + '\n')
        if accuracy > best_model[1]:
            best_model = (model_path, accuracy)
            
        print 'best_model:' + '\t' + str(best_model[0]) + '\t' + str(best_model[1])

    line = 'best_model:' + '\t' + str(best_model[0]) + '\t' + str(best_model[1])
    model_accuracies_txt.write(line)
    model_accuracies_txt.close()
