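# train.py (34.42 KiB) -- source extracted from a repository web viewer;
# the viewer's navigation text ("Skip to content", "Snippets Groups Projects")
# was not part of the program.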
from word2vec.word_vector_management import word_vector_manager
import object_attribute_classifier_cached_features.inference as feature_graph 
import region_relevance_network.inference as relevance_graph
import answer_classifier_cached_features.inference as answer_graph
from tftools import var_collect, placeholder_management
import tftools.train as multi_rate_train
import tftools.data
import losses
import constants
import tftools.var_collect as var_collect
import data.vqa_cached_features as vqa_data
import data.cropped_regions_cached_features as genome_data

import numpy as np
import pdb
from itertools import izip
import tensorflow as tf
import gc

class graph_creator():
    def __init__(
            self,
            tb_log_dir,
            batch_size,
            image_size,
            num_neg_answers,
            space_dim,
            regularization_coeff,
            num_regions_wo_labels,
            num_regions_w_labels,
            num_object_labels,
            num_attribute_labels,
            obj_atr_loss_wt,
            ans_loss_wt,
            mil_loss_wt,
            resnet_feat_dim=2048,
            training=True):
        """Store the configuration and assemble the complete graph.

        Everything is built inside a private ``tf.Graph``: placeholders,
        the object/attribute network, per-example embeddings, the region
        relevance and answer networks, losses, accuracies, the merged
        summary op and the summary writer.

        Args:
            tb_log_dir: directory handed to the summary writer.
            batch_size: number of examples per batch; region embeddings
                are split into this many pieces along dimension 0.
            image_size: pair unpacked into ``im_h`` and ``im_w``.
            num_neg_answers: number of negative answers per example.
            space_dim: forwarded to the answer network.
            regularization_coeff: coefficient of the regularization loss.
            num_regions_wo_labels: number of leading rows of the
                concatenated region features that come from
                ``region_feats`` (used for slicing in training mode).
            num_regions_w_labels: forwarded to the attribute loss.
            num_object_labels: width of the object label placeholder.
            num_attribute_labels: width of the attribute label placeholder.
            obj_atr_loss_wt: weight on the object and attribute losses.
            ans_loss_wt: weight on the answer loss.
            mil_loss_wt: weight on the two MIL losses.
            resnet_feat_dim: width of the region feature placeholders.
            training: when true, the labeled-region inputs and their
                losses/accuracies are added to the graph.
        """
        # Plain configuration, read later by the builder methods.
        self.training = training
        self.batch_size = batch_size
        self.im_h, self.im_w = image_size
        self.resnet_feat_dim = resnet_feat_dim
        self.space_dim = space_dim
        self.num_neg_answers = num_neg_answers
        self.num_regions_wo_labels = num_regions_wo_labels
        self.num_regions_w_labels = num_regions_w_labels
        self.num_object_labels = num_object_labels
        self.num_attribute_labels = num_attribute_labels
        self.regularization_coeff = regularization_coeff
        self.obj_atr_loss_wt = obj_atr_loss_wt
        self.ans_loss_wt = ans_loss_wt
        self.mil_loss_wt = mil_loss_wt

        self.tf_graph = tf.Graph()
        with self.tf_graph.as_default():
            self.create_placeholders()
            self.word_vec_mgr = word_vector_manager()

            # Training runs both feature sets through the object/attribute
            # network in one pass: 'region_feats' rows first, then the
            # labeled rows.
            if not self.training:
                self.concat_feats = self.plh['region_feats']
            else:
                stacked_inputs = [
                    self.plh['region_feats'],
                    self.plh['region_feats_with_labels'],
                ]
                self.concat_feats = tf.concat(0, stacked_inputs)

            self.obj_atr_inference = feature_graph.ObjectAttributeInference(
                self.concat_feats,
                self.word_vec_mgr.object_label_vectors,
                self.word_vec_mgr.attribute_label_vectors,
                self.training)

            # Only the rows that came from 'region_feats' feed the
            # relevance and answer networks.
            if not self.training:
                object_embed = self.obj_atr_inference.object_embed
                attribute_embed = self.obj_atr_inference.attribute_embed
            else:
                self.split_obj_atr_inference_output()
                object_embed = self.object_embed_with_answers
                attribute_embed = self.attribute_embed_with_answers

            # One tensor per example in the batch.
            self.object_feat = tf.split(0, self.batch_size, object_embed)
            self.attribute_feat = tf.split(
                0, self.batch_size, attribute_embed)

            self.question_embed, self.question_embed_concat = (
                self.get_question_embeddings())
            self.answers_embed, self.answers_embed_concat = (
                self.get_answer_embeddings())
            self.noun_embed, self.adjective_embed = (
                self.get_noun_adjective_embeddings())

            self.relevance_inference = (
                relevance_graph.RegionRelevanceInference(
                    self.batch_size,
                    self.object_feat,
                    self.attribute_feat,
                    self.noun_embed,
                    self.adjective_embed))

            self.answer_inference = answer_graph.AnswerInference(
                self.object_feat,
                self.attribute_feat,
                self.relevance_inference.answer_region_prob,
                self.question_embed_concat,
                self.answers_embed_concat,
                self.noun_embed,
                self.adjective_embed,
                self.num_neg_answers + 1,
                self.space_dim,
                self.plh['keep_prob'],
                self.training)

            self.add_losses()
            self.add_accuracy_computation()
            self.collect_variables()

            # Snapshot of the variables that exist at this point; anything
            # created on the graph afterwards is not part of this list.
            self.vars_to_save = tf.all_variables()
            self.merged = tf.merge_all_summaries()
            self.writer = tf.train.SummaryWriter(
                tb_log_dir, graph=self.tf_graph)

    def create_placeholders(self):
        """Register every graph input on a new ``PlaceholderManager``.

        The manager is stored as ``self.plh``. Registration order is:
        ``keep_prob``, ``region_feats``, the three training-only inputs
        (``region_feats_with_labels``, ``object_labels``,
        ``attribute_labels``), the int64 word-index inputs, and finally the
        two float ``*_vec_enc`` inputs.
        """
        self.plh = placeholder_management.PlaceholderManager()

        def add_per_example(name, dtype, shape):
            # Registered with size=batch_size; other methods index the
            # result as self.plh[name][j], so this presumably yields one
            # placeholder per example -- confirm in PlaceholderManager.
            self.plh.add_placeholder(
                name, dtype, shape=shape, size=self.batch_size)

        self.plh.add_placeholder('keep_prob', tf.float32, shape=[])
        self.plh.add_placeholder(
            'region_feats', tf.float32, shape=[None, self.resnet_feat_dim])

        if self.training:
            self.plh.add_placeholder(
                'region_feats_with_labels',
                tf.float32,
                shape=[None, self.resnet_feat_dim])
            self.plh.add_placeholder(
                'object_labels',
                tf.float32,
                shape=[None, self.num_object_labels])
            self.plh.add_placeholder(
                'attribute_labels',
                tf.float32,
                shape=[None, self.num_attribute_labels])

        # Names of all variable-length int64 inputs, in registration order.
        neg_suffixes = [str(i) for i in xrange(self.num_neg_answers)]
        index_inputs = ['negative_answer_' + s for s in neg_suffixes]
        index_inputs.append('positive_answer')
        index_inputs.extend(['bin_' + str(i) for i in xrange(4)])
        index_inputs.extend(['positive_nouns', 'positive_adjectives'])
        for s in neg_suffixes:
            index_inputs.append('negative_nouns_' + s)
            index_inputs.append('negative_adjectives_' + s)

        for input_name in index_inputs:
            add_per_example(input_name, tf.int64, [None])

        for input_name in ('positive_nouns_vec_enc',
                           'positive_adjectives_vec_enc'):
            add_per_example(input_name, tf.float32, [1, None])

    def get_noun_adjective_embeddings(self):
        """Look up word vectors for the noun and adjective inputs.

        Returns:
            ``(noun_embed, adjective_embed)``: two dicts keyed by
            placeholder name (``positive_nouns``, ``negative_nouns_<i>``,
            ``positive_adjectives``, ``negative_adjectives_<i>``). Each
            value is a list with one ``embedding_lookup`` result per
            example in the batch; no averaging is applied here.
        """
        def lookup_each_example(plh_name):
            return [
                tf.nn.embedding_lookup(
                    self.word_vec_mgr.word_vectors,
                    self.plh[plh_name][j],
                    name='embedding_lookup_' + plh_name)
                for j in xrange(self.batch_size)]

        with tf.variable_scope('noun_adjective_embed'):
            noun_embed = {}
            adjective_embed = {}

            # Nouns are processed before adjectives; within each kind the
            # positive input comes first, then the negatives in index order.
            for kind, table in (('nouns', noun_embed),
                                ('adjectives', adjective_embed)):
                plh_names = ['positive_' + kind]
                plh_names.extend(
                    ['negative_' + kind + '_' + str(i)
                     for i in xrange(self.num_neg_answers)])
                for plh_name in plh_names:
                    table[plh_name] = lookup_each_example(plh_name)

        return noun_embed, adjective_embed
                
    def get_question_embeddings(self):
        """Embed the four question bins of every example.

        Returns:
            ``(question_bin_embed, question_bin_embed_concat)``. The first
            is a dict mapping ``'bin_<i>'`` to a list holding one embedding
            per example. The second is a list with one tensor per example,
            obtained by concatenating that example's four bin embeddings
            along dimension 0 in bin order.
        """
        with tf.variable_scope('question_bin_embed'):
            examples = xrange(self.batch_size)
            bin_names = ['bin_' + str(i) for i in xrange(4)]

            question_bin_embed = {}
            for bin_name in bin_names:
                question_bin_embed[bin_name] = [
                    self.lookup_word_embeddings(
                        self.plh[bin_name][j], bin_name)
                    for j in examples]

            question_bin_embed_concat = [
                tf.concat(
                    0,
                    [question_bin_embed[b][j] for b in bin_names],
                    name='concat_question_bins')
                for j in examples]

        return question_bin_embed, question_bin_embed_concat

    def get_answer_embeddings(self):
        """Embed the positive answer and every negative answer.

        Returns:
            ``(answers_embed, answers_embed_concat)``. The first is a dict
            mapping ``'positive_answer'`` / ``'negative_answer_<i>'`` to a
            list holding one embedding per example. The second is a list
            with one tensor per example: that example's answer embeddings
            concatenated along dimension 0, positive answer first and then
            the negatives in index order.
        """
        with tf.variable_scope('answers_embed'):
            examples = xrange(self.batch_size)
            answer_names = ['positive_answer']
            answer_names.extend(
                ['negative_answer_' + str(i)
                 for i in xrange(self.num_neg_answers)])

            answers_embed = {}
            for answer_name in answer_names:
                answers_embed[answer_name] = [
                    self.lookup_word_embeddings(
                        self.plh[answer_name][j], answer_name)
                    for j in examples]

            answers_embed_concat = [
                tf.concat(
                    0,
                    [answers_embed[a][j] for a in answer_names],
                    name='concat_answers')
                for j in examples]

        return answers_embed, answers_embed_concat
        
    def lookup_word_embeddings(self, index_list, name):
        """Average the word vectors selected by ``index_list``.

        Args:
            index_list: indices into ``self.word_vec_mgr.word_vectors``.
            name: variable scope under which the ops are created.

        Returns:
            The mean of the looked-up vectors over dimension 0, with that
            dimension kept (``keep_dims=True``).
        """
        with tf.variable_scope(name):
            looked_up = tf.nn.embedding_lookup(
                self.word_vec_mgr.word_vectors,
                index_list,
                name='embedding_lookup')

            return tf.reduce_mean(
                looked_up,
                reduction_indices=0,
                keep_dims=True,
                name='reduce_mean')
                
    def split_obj_atr_inference_output(self):
        """Separate the object/attribute outputs by input source.

        The first ``num_regions_wo_labels`` rows are stored in the
        ``*_with_answers`` attributes and the remaining rows in the
        ``*_with_labels`` attributes. The ``*_scores_with_answers`` tensors
        are additionally split into ``batch_size`` pieces along
        dimension 0.
        """
        inference = self.obj_atr_inference
        num_head_rows = self.num_regions_wo_labels

        def head_rows(tensor):
            # Rows [0, num_head_rows), all columns.
            return tf.slice(tensor, [0, 0], [num_head_rows, -1])

        def tail_rows(tensor):
            # Rows from num_head_rows to the end, all columns.
            return tf.slice(tensor, [num_head_rows, 0], [-1, -1])

        with tf.variable_scope('split'):
            self.object_embed_with_answers = head_rows(
                inference.object_embed)
            self.object_scores_with_labels = tail_rows(
                inference.object_scores)
            self.object_scores_with_answers = tf.split(
                0, self.batch_size, head_rows(inference.object_scores))

            self.attribute_embed_with_answers = head_rows(
                inference.attribute_embed)
            self.attribute_scores_with_labels = tail_rows(
                inference.attribute_scores)
            self.attribute_scores_with_answers = tf.split(
                0, self.batch_size, head_rows(inference.attribute_scores))

    def add_losses(self):
        """Create every loss term, their sum, and the loss summaries.

        Sets ``answer_loss``, ``object_loss``, ``attribute_loss``,
        ``mil_obj_loss``, ``mil_atr_loss``, ``regularization_loss`` and
        ``total_loss``. Outside training mode the object, attribute and
        MIL terms are the constant ``0.0`` and get no summaries.
        """
        # Target row for the answer loss: 1 in column 0, 0 elsewhere.
        # get_answer_embeddings places the positive answer first.
        target = np.zeros([1, self.num_neg_answers + 1])
        target[0, 0] = 1.0
        target = tf.constant(target, dtype=tf.float32)

        # Mean per-example answer loss, then weighted.
        answer_loss_sum = 0
        for example in xrange(self.batch_size):
            answer_loss_sum += losses.answer_loss(
                self.answer_inference.answer_score[example],
                target)
        self.answer_loss = (
            (answer_loss_sum / self.batch_size) * self.ans_loss_wt)

        if not self.training:
            self.object_loss = 0.0
            self.attribute_loss = 0.0
            self.mil_obj_loss = 0.0
            self.mil_atr_loss = 0.0
        else:
            self.object_loss = self.obj_atr_loss_wt*losses.object_loss(
                self.object_scores_with_labels,
                self.plh['object_labels'])
            tf.scalar_summary("loss_object", self.object_loss)

            # The attribute term carries an extra fixed factor of 1000.
            self.attribute_loss = (
                1000*self.obj_atr_loss_wt*losses.attribute_loss(
                    self.attribute_scores_with_labels,
                    self.plh['attribute_labels'],
                    self.num_regions_w_labels))
            tf.scalar_summary("loss_attribute", self.attribute_loss)

            # MIL terms use the per-example score splits, which exist only
            # in training mode.
            mil_obj_sum = 0.0
            mil_atr_sum = 0.0
            for example in xrange(self.batch_size):
                mil_obj_sum += losses.mil_loss(
                    self.object_scores_with_answers[example],
                    self.plh['positive_nouns_vec_enc'][example],
                    'obj')
                mil_atr_sum += losses.mil_loss(
                    self.attribute_scores_with_answers[example],
                    self.plh['positive_adjectives_vec_enc'][example],
                    'atr')

            self.mil_obj_loss = (
                self.mil_loss_wt*mil_obj_sum / self.batch_size)
            self.mil_atr_loss = (
                self.mil_loss_wt*mil_atr_sum / self.batch_size)

            tf.scalar_summary("loss_mil_obj", self.mil_obj_loss)
            tf.scalar_summary("loss_mil_atr", self.mil_atr_loss)

        self.regularization_loss = self.regularization()

        self.total_loss = (
            self.object_loss + self.attribute_loss +
            self.regularization_loss +
            self.answer_loss +
            self.mil_obj_loss + self.mil_atr_loss)

        # The answer-loss summary reports an exponential moving average
        # (decay 0.95); the summary op is created under a control
        # dependency on the average's update op.
        ema = tf.train.ExponentialMovingAverage(0.95, name='ema')
        ema_update = ema.apply([self.answer_loss])
        smoothed_answer_loss = ema.average(self.answer_loss)
        with tf.control_dependencies([ema_update]):
            tf.scalar_summary("loss_answer", smoothed_answer_loss)

        tf.scalar_summary("loss_regularization", self.regularization_loss)
        tf.scalar_summary("loss_total", self.total_loss)

    def regularization(self):
        """Return the regularization loss for the 'to_regularize' collection.

        The variables are read from the graph collection named
        ``'to_regularize'`` and passed to ``losses.regularization_loss``
        together with ``self.regularization_coeff``.
        """
        return losses.regularization_loss(
            tf.get_collection('to_regularize'),
            self.regularization_coeff)

    def add_accuracy_computation(self):
        """Add accuracy ops and their summaries under 'accuracy_graph'.

        Answer accuracy is always added and its summary reports the
        moving average. Object and attribute accuracy are added only in
        training mode, because they read the labeled-region scores and
        label placeholders.
        """
        with tf.variable_scope('accuracy_graph'):
            (self.answer_accuracy,
             self.answer_accuracy_ema,
             self.update_answer_accuracy_op) = (
                 self.add_answer_accuracy_computation(
                     self.answer_inference.answer_score))

            self.moving_average_accuracy = (
                self.answer_accuracy_ema.average(self.answer_accuracy))

            # Summary op is created under a control dependency on the
            # moving-average update.
            with tf.control_dependencies([self.update_answer_accuracy_op]):
                tf.scalar_summary(
                    ["accuracy_answer"],
                    self.moving_average_accuracy)

            if not self.training:
                return

            self.object_accuracy = self.add_object_accuracy_computation(
                self.object_scores_with_labels,
                self.plh['object_labels'])
            tf.scalar_summary("accuracy_object", self.object_accuracy)

            self.attribute_accuracy = (
                self.add_attribute_accuracy_computation(
                    self.attribute_scores_with_labels,
                    self.plh['attribute_labels']))
            tf.scalar_summary("accuracy_attribute", self.attribute_accuracy)
            
    def add_answer_accuracy_computation(self, scores):
        """Build the batch answer accuracy and a moving average of it.

        An example counts as correct when the argmax of its scores along
        dimension 1 equals 0.

        Args:
            scores: indexable by example, ``scores[j]`` being the answer
                scores of example ``j``.

        Returns:
            ``(accuracy, ema, update_accuracy_op)``: the fraction of
            correct examples, the ``ExponentialMovingAverage`` object
            (decay 0.95) and the op that updates its average.
        """
        with tf.variable_scope('answer_accuracy'):
            num_correct = 0.0
            for example in xrange(self.batch_size):
                predicted = tf.argmax(scores[example], 1)
                hit = tf.equal(predicted, tf.constant(0, dtype=tf.int64))
                num_correct += tf.cast(hit, tf.float32)
            accuracy = num_correct / self.batch_size

            ema = tf.train.ExponentialMovingAverage(0.95, name='ema')
            update_accuracy_op = ema.apply([accuracy])

        return accuracy, ema, update_accuracy_op

    def add_object_accuracy_computation(self, scores, labels):
        """Return the fraction of rows whose top score matches the label.

        Args:
            scores: per-row object scores.
            labels: per-row label vectors with the same layout as
                ``scores``; the argmax along dimension 1 is taken as the
                true class.
        """
        with tf.variable_scope('object_accuracy'):
            predicted_class = tf.argmax(scores, 1)
            labeled_class = tf.argmax(labels, 1)
            hits = tf.equal(
                predicted_class, labeled_class, name='correct_prediction')

            return tf.reduce_mean(
                tf.cast(hits, tf.float32), name='accuracy')

    def add_attribute_accuracy_computation(self, scores, labels):
        """Return the element-wise accuracy of thresholded attribute scores.

        A score greater than 0.0 is read as a positive prediction; it is
        compared with ``labels`` cast to bool, and the mean agreement over
        all elements is returned.

        Args:
            scores: attribute scores.
            labels: attribute labels, same layout as ``scores``.
        """
        # The scope used to be 'object_accuracy' (copied from the object
        # version), which put these ops under a misleading name. No
        # variables are created here, so only op names change.
        with tf.variable_scope('attribute_accuracy'):
            thresholded = tf.greater(
                scores,
                0.0,
                name='thresholded')

            correct_prediction = tf.equal(
                thresholded,
                tf.cast(labels, tf.bool),
                name='correct_prediction')

            attribute_accuracy = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32),
                name='accuracy')

        return attribute_accuracy

    def collect_variables(self):
        """Group the graph's variables by the scope they were created in.

        Sets ``word_vec_vars``, ``resnet_vars`` (taken from the
        object/attribute network), ``object_attribute_vars`` (the
        'object_graph' scope followed by the 'attribute_graph' scope) and
        ``answer_vars``.
        """
        collect = var_collect.collect_scope

        self.word_vec_vars = collect('word_vectors')
        self.resnet_vars = self.obj_atr_inference.resnet_vars
        # NOTE: the 'bn' scope is not collected here.
        self.object_attribute_vars = (
            collect('object_graph') + collect('attribute_graph'))
        self.answer_vars = collect('answer_graph')

def create_initializer(graph, sess, model):
    """Restore pretrained variables and prepare an init op for the rest.

    The word-vector and object/attribute variables are restored from
    ``model`` as soon as this function is called. The returned object's
    ``initialize()`` then runs the initializer for every other variable
    present on the graph at that moment.

    Args:
        graph: a ``graph_creator`` instance.
        sess: session used for both the restore and the initialization.
        model: checkpoint path passed to ``Saver.restore``.
    """
    class initializer():
        def __init__(self):
            with graph.tf_graph.as_default():
                restored_vars = (
                    graph.word_vec_vars + graph.object_attribute_vars)

                tf.train.Saver(restored_vars).restore(sess, model)

                # Whatever was not restored still needs initial values.
                fresh_vars = [
                    v for v in tf.all_variables()
                    if v not in restored_vars]
                var_collect.print_var_list(fresh_vars, 'vars_to_init')
                self.init = tf.initialize_variables(fresh_vars)

        def initialize(self):
            sess.run(self.init)

    return initializer()

def create_scratch_initializer(graph, sess):
    """Prepare an init op covering every variable on the graph.

    Nothing is restored from a checkpoint. The returned object's
    ``initialize()`` runs the initializer in ``sess`` for all variables
    that existed on ``graph.tf_graph`` when this function was called.
    """
    class initializer():
        def __init__(self):
            with graph.tf_graph.as_default():
                every_var = tf.all_variables()
                var_collect.print_var_list(every_var, 'vars_to_init')
                self.init = tf.initialize_variables(every_var)

        def initialize(self):
            sess.run(self.init)

    return initializer()

def create_vqa_batch_generator():
    """Return an asynchronous batch generator over the VQA training subset.

    All settings come from ``constants``. The number of questions in the
    subset is printed. Indices are drawn with ``tftools.data.random``
    using the answer batch size and epoch count, and batches are queued
    with a queue of ``constants.answer_queue_size``.
    """
    data_mgr = vqa_data.data(
        constants.vqa_train_resnet_feat_dir,
        constants.vqa_train_anno,
        constants.vqa_train_subset_qids,
        constants.vocab_json,
        constants.vqa_answer_vocab_json,
        constants.object_labels_json,
        constants.attribute_labels_json,
        constants.image_size,
        constants.num_region_proposals,
        constants.num_negative_answers,
        resnet_feat_dim=constants.resnet_feat_dim)

    num_train_subset_questions = len(data_mgr.qids)
    print(num_train_subset_questions)

    index_generator = tftools.data.random(
        constants.answer_batch_size,
        num_train_subset_questions,
        constants.answer_num_epochs,
        0)

    return tftools.data.async_batch_generator(
        data_mgr,
        index_generator,
        constants.answer_queue_size)


def create_vgenome_batch_generator():
    """Return an asynchronous batch generator over labeled genome regions.

    All settings come from ``constants``. The number of training regions
    is printed. Each batch holds ``constants.num_regions_with_labels``
    indices and batches are queued with a queue of
    ``constants.region_queue_size``.
    """
    data_mgr = genome_data.data(
        constants.genome_resnet_feat_dir,
        constants.image_dir,
        constants.object_labels_json,
        constants.attribute_labels_json,
        constants.regions_json,
        constants.genome_train_region_ids,
        constants.image_size,
        channels=3,
        resnet_feat_dim=constants.resnet_feat_dim,
        mean_image_filename=None)

    num_train_regions = len(data_mgr.region_ids)
    print(num_train_regions)

    index_generator = tftools.data.random(
        constants.num_regions_with_labels,
        num_train_regions,
        constants.region_num_epochs,
        0)

    return tftools.data.async_batch_generator(
        data_mgr,
        index_generator,
        constants.region_queue_size)


def create_batch_generator():
    """Return an iterator of ``(vqa_batch, vgenome_batch)`` pairs.

    The VQA generator is created before the genome generator. ``izip``
    pairs them lazily, so iteration ends when either one is exhausted.
    """
    vqa_batches = create_vqa_batch_generator()
    vgenome_batches = create_vgenome_batch_generator()
    return izip(vqa_batches, vgenome_batches)


class attach_optimizer():
    """Add an Adam training op with a decaying learning rate to a graph.

    After construction:
        train_op: op that minimizes ``graph.total_loss`` and advances
            ``global_step``.
        learning_rate: ``lr`` under ``tf.train.exponential_decay`` with
            the given ``decay_step`` and ``decay_rate``.
        opt_vars: variables created while attaching the optimizer; the
            snapshot they are compared against is taken before
            ``global_step`` is created, so ``global_step`` is included.
    """

    def __init__(self, graph, lr, decay_step=24000, decay_rate=0.5):
        self.graph = graph
        self.lr = lr
        self.decay_step = decay_step
        self.decay_rate = decay_rate
        with graph.tf_graph.as_default():
            # No variable is excluded from training at present.
            self.not_to_train = []
            vars_to_train = self.filter_out_vars_to_train(
                tf.trainable_variables())

            var_collect.print_var_list(vars_to_train, 'vars_to_train')

            # Snapshot used below to identify optimizer-created variables.
            vars_before_optimizer = tf.all_variables()
            self.ops = dict()

            self.global_step = tf.Variable(0, trainable=False)
            self.learning_rate = tf.train.exponential_decay(
                self.lr,
                self.global_step,
                self.decay_step,
                self.decay_rate)

            self.optimizer = multi_rate_train.MultiRateOptimizer(
                tf.train.AdamOptimizer)
            self.optimizer.add_variables(
                vars_to_train,
                learning_rate=self.learning_rate)

            self.train_op = self.optimizer.minimize(
                graph.total_loss,
                self.global_step)

            self.opt_vars = [
                v for v in tf.all_variables()
                if v not in vars_before_optimizer]

    def filter_out_vars_to_train(self, var_list):
        """Return ``var_list`` without the members of ``self.not_to_train``."""
        kept = []
        for candidate in var_list:
            if candidate not in self.not_to_train:
                kept.append(candidate)
        return kept

    def add_adam_optimizer(self, loss, var_list, name):
        """Store under ``self.ops[name]`` an Adam step for ``loss``.

        Uses the constant ``self.lr`` (not the decayed rate). If no
        variable is left after filtering, ``[]`` is stored instead.
        """
        trainable = self.filter_out_vars_to_train(var_list)
        if trainable:
            adam = tf.train.AdamOptimizer(self.lr)
            self.ops[name] = adam.minimize(loss, var_list=trainable)
        else:
            self.ops[name] = []

    def group_all_train_ops(self):
        """Return one op grouping every non-empty entry of ``self.ops``."""
        grouped = tf.group()
        for step_op in filter(None, self.ops.values()):
            grouped = tf.group(grouped, step_op)
        return grouped


def create_feed_dict_creator(plh, num_neg_answers, keep_prob=0.8):
    """Return a function that turns a batch pair into a feed dict.

    Args:
        plh: object whose ``get_feed_dict(inputs)`` maps a dict keyed by
            placeholder name to the feed dict that is returned.
        num_neg_answers: number of negative answers per example.
        keep_prob: value fed to the ``'keep_prob'`` input. Defaults to
            0.8, the value that used to be hard-coded; pass another value
            (e.g. 1.0) to reuse this creator outside training.

    Returns:
        ``feed_dict_creator(batch)``, where ``batch`` is a
        ``(vqa_batch, vgenome_batch)`` pair of dicts.
    """
    def feed_dict_creator(batch):
        vqa_batch, vgenome_batch = batch
        batch_size = len(vqa_batch['question'])
        examples = range(batch_size)

        inputs = {
            # VQA inputs; per-example region features are stacked along
            # axis 0 into a single array.
            'region_feats': np.concatenate(vqa_batch['region_feats'], axis=0),
            'positive_answer': vqa_batch['positive_answer'],
            'positive_nouns': vqa_batch['positive_nouns'],
            'positive_adjectives': vqa_batch['positive_adjectives'],
            'positive_nouns_vec_enc': vqa_batch['positive_nouns_vec_enc'],
            'positive_adjectives_vec_enc':
                vqa_batch['positive_adjectives_vec_enc'],
            # Genome (labeled region) inputs.
            'region_feats_with_labels': vgenome_batch['region_feats'],
            'object_labels': vgenome_batch['object_labels'],
            'attribute_labels': vgenome_batch['attribute_labels'],
            'keep_prob': keep_prob,
        }

        # The four question bins, one list entry per example.
        for i in range(4):
            bin_name = 'bin_' + str(i)
            inputs[bin_name] = [
                vqa_batch['question'][j][bin_name] for j in examples]

        for i in range(num_neg_answers):
            suffix = str(i)

            inputs['negative_answer_' + suffix] = [
                vqa_batch['negative_answers'][j][i] for j in examples]

            # For each negative answer, the fed nouns/adjectives are the
            # question's own followed by those of that negative answer.
            ith_answer_nouns = [
                vqa_batch['negative_answers_nouns'][j][i] for j in examples]
            inputs['negative_nouns_' + suffix] = [
                question_part + answer_part
                for question_part, answer_part in zip(
                    vqa_batch['question_nouns'], ith_answer_nouns)]

            ith_answer_adjectives = [
                vqa_batch['negative_answers_adjectives'][j][i]
                for j in examples]
            inputs['negative_adjectives_' + suffix] = [
                question_part + answer_part
                for question_part, answer_part in zip(
                    vqa_batch['question_adjectives'], ith_answer_adjectives)]

        return plh.get_feed_dict(inputs)

    return feed_dict_creator


class log_mgr():
    """Write summaries, print diagnostics and save checkpoints."""

    def __init__(
            self,
            graph,
            vars_to_save,
            sess,
            log_every_n_iter,
            output_dir,
            model_path):
        """Keep the handles and create the checkpoint saver.

        Args:
            graph: object exposing ``writer`` (the summary writer).
            vars_to_save: variables handed to the saver.
            sess: session whose variable values get saved.
            log_every_n_iter: checkpoint period, in iterations.
            output_dir: stored on the instance; not used in this class.
            model_path: path passed to ``Saver.save``.
        """
        self.graph = graph
        self.sess = sess
        self.vars_to_save = vars_to_save
        self.log_every_n_iter = log_every_n_iter
        self.output_dir = output_dir
        self.model_path = model_path
        self.loss_values = dict()

        # max_to_keep=0: per the tf.train.Saver documentation this keeps
        # all checkpoint files instead of deleting older ones.
        self.model_saver = tf.train.Saver(
            var_list=vars_to_save,
            max_to_keep=0)

    def log(self, iter, is_last=False, eval_vars_dict=None):
        """Record one iteration.

        With a non-empty ``eval_vars_dict`` the merged summary is written
        and shape/extreme-value diagnostics plus the learning rate are
        printed. A checkpoint is saved when ``iter`` is a non-zero
        multiple of ``log_every_n_iter`` or when ``is_last`` is true (and
        ``iter`` is non-zero).
        """
        if eval_vars_dict:
            self.graph.writer.add_summary(eval_vars_dict['merged'], iter)

            # (header, key, also print the minimum?)
            diagnostics = (
                ('Word Vector shape: {}', 'word_vectors', True),
                ('Object Scores shape: {}', 'object_scores', False),
                ('Attribute Scores shape: {}', 'attribute_scores', False),
                ('Answer Scores shape: {}', 'answer_scores', False),
                ('Relevance Prob shape: {}', 'relevance_prob', False),
                ('Per region answer prob shape: {}',
                 'per_region_answer_prob', False),
            )
            for header, key, with_minimum in diagnostics:
                print(header.format(eval_vars_dict[key].shape))
                print(np.max(eval_vars_dict[key]))
                if with_minimum:
                    print(np.min(eval_vars_dict[key]))

            print('Learning Rate: {}'.format(eval_vars_dict['lr']))

        on_schedule = iter % self.log_every_n_iter == 0
        if (on_schedule or is_last) and iter != 0:
            self.model_saver.save(
                self.sess,
                self.model_path,
                global_step=iter)
            

def train(
        batch_generator, 
        sess, 
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        logger):
    """Run the training loop over ``batch_generator``.

    For every batch the tensors/ops in ``vars_to_eval_dict`` are run in
    ``sess`` and the results, keyed by the same names, are handed to
    ``logger.log``. After the last batch the logger is called once more
    with ``is_last=True`` and the final results.

    If ``batch_generator`` yields nothing, only the initializer runs and
    the logger is never called. This used to fail with a ``NameError``
    because the final call referenced results that were never produced.

    Args:
        batch_generator: iterable of batches.
        sess: session; used as the default session for the whole loop.
        initializer: object with an ``initialize()`` method, called once
            before the first batch.
        vars_to_eval_dict: dict mapping a name to a fetchable tensor/op.
        feed_dict_creator: callable mapping a batch to a feed dict.
        logger: object with ``log(iter, is_last, eval_vars_dict)``.
    """
    vars_to_eval_names = []
    vars_to_eval = []
    for var_name, var in vars_to_eval_dict.items():
        vars_to_eval_names.append(var_name)
        vars_to_eval.append(var)

    with sess.as_default():
        initializer.initialize()

        step = 0
        eval_vars_dict = None
        for batch in batch_generator:
            print('---')
            print('Iter: {}'.format(step))
            feed_dict = feed_dict_creator(batch)
            eval_vars = sess.run(
                vars_to_eval,
                feed_dict=feed_dict)
            eval_vars_dict = dict(zip(vars_to_eval_names, eval_vars))
            logger.log(step, False, eval_vars_dict)
            step += 1

            # Explicit garbage collection every 8000 iterations.
            if step % 8000 == 0:
                gc.collect()

        if step == 0:
            # No batch was processed: there is nothing to log, and step - 1
            # would not be a valid iteration number for a checkpoint.
            return

        logger.log(step - 1, True, eval_vars_dict)


if __name__ == '__main__':
    # Script entry point: build the data pipeline, graph, optimizer,
    # session and logger from `constants`, then run training.
    print('Creating batch generator...')
    batch_generator = create_batch_generator()

    print('Creating computation graph...')
    # The graph receives answer_batch_size * num_region_proposals as
    # num_regions_wo_labels.
    num_regions_wo_labels = (
        constants.answer_batch_size*constants.num_region_proposals)
    graph = graph_creator(
        constants.tb_log_dir,
        constants.answer_batch_size,
        constants.image_size,
        constants.num_negative_answers,
        constants.answer_embedding_dim,
        constants.answer_regularization_coeff,
        num_regions_wo_labels,
        constants.num_regions_with_labels,
        constants.num_object_labels,
        constants.num_attribute_labels,
        constants.answer_obj_atr_loss_wt,
        constants.answer_ans_loss_wt,
        constants.answer_mil_loss_wt,
        resnet_feat_dim=constants.resnet_feat_dim,
        training=True)

    print('Attaching optimizer...')
    optimizer = attach_optimizer(graph, constants.answer_lr)

    print('Starting a session...')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config, graph=graph.tf_graph)

    print('Creating initializer...')
    if not constants.answer_train_from_scratch:
        initializer = create_initializer(
            graph, sess, constants.pretrained_model)
    else:
        initializer = create_scratch_initializer(graph, sess)

    print('Creating feed dict creator...')
    feed_dict_creator = create_feed_dict_creator(
        graph.plh, constants.num_negative_answers)

    print('Creating dict of vars to be evaluated...')
    # Everything fetched on each iteration. 'optimizer_op' is the training
    # step itself; for the per-example lists only example 0 is fetched.
    vars_to_eval_dict = {
        'optimizer_op': optimizer.train_op,
        'merged': graph.merged,
        'lr': optimizer.learning_rate,
        'total_loss': graph.total_loss,
        'accuracy': graph.moving_average_accuracy,
        'word_vectors': graph.word_vec_mgr.word_vectors,
        'object_scores': graph.obj_atr_inference.object_scores,
        'attribute_scores': graph.obj_atr_inference.attribute_scores,
        'answer_scores': graph.answer_inference.answer_score[0],
        'relevance_prob': graph.relevance_inference.answer_region_prob[0],
        'per_region_answer_prob':
            graph.answer_inference.per_region_answer_prob[0],
    }

    print('Creating logger...')
    vars_to_save = graph.vars_to_save
    logger = log_mgr(
        graph,
        vars_to_save,
        sess,
        constants.answer_log_every_n_iter,
        constants.answer_output_dir,
        constants.answer_model)

    print('Start training...')
    train(
        batch_generator,
        sess,
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        logger)