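"""Fine-tune the answer classifier from cached features.

Builds the training graph via answer_classifier_cached_features.train,
restores previously saved model variables, and resumes training from
constants.answer_fine_tune_from_iter, logging evaluated tensors each
iteration.
"""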
# from word2vec.word_vector_management import word_vector_manager
# import object_attribute_classifier_cached_features.inference as feature_graph 
# import region_relevance_network.inference as relevance_graph
# import answer_classifier_cached_features.inference as answer_graph
from tftools import var_collect, placeholder_management
import tftools.data
import losses
import constants
import data.vqa
import answer_classifier_cached_features.train as train

import numpy as np
import tensorflow as tf


def create_initializer(graph, sess, model):
    """Restore saved model variables and initialize all others."""
    class Initializer(object):
        def __init__(self):
            with graph.tf_graph.as_default():
                # Restore the variables that were saved with the model.
                model_vars = graph.vars_to_save
                model_restorer = tf.train.Saver(model_vars)
                model_restorer.restore(sess, model)
                # Every remaining variable still needs initialization.
                other_vars = [var for var in tf.global_variables()
                              if var not in model_vars]
                var_collect.print_var_list(other_vars, 'vars_to_init')
                self.init = tf.variables_initializer(other_vars)

        def initialize(self):
            sess.run(self.init)

    return Initializer()

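# For reference, a minimal self-contained sketch of the partial-restore
# pattern used by create_initializer above, without the project's
# graph/var_collect helpers. `saved_vars` and `ckpt_path` are hypothetical
# names; a TF 1.x-style graph and session are assumed.
def _partial_restore_sketch(sess, saved_vars, ckpt_path):
    # Restore only the checkpointed subset of variables.
    tf.train.Saver(saved_vars).restore(sess, ckpt_path)
    # Initialize everything else (e.g. freshly added layers, optimizer
    # slots) so no variable is left uninitialized.
    remaining = [v for v in tf.global_variables() if v not in saved_vars]
    sess.run(tf.variables_initializer(remaining))
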

def fine_tune(
        fine_tune_from_iter,
        batch_generator,
        sess,
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        logger):
    """Run the fine-tuning loop, logging evaluated tensors each iteration."""
    # Keep names and tensors in matching order so the evaluated values
    # can be zipped back into a dict.
    vars_to_eval_names = list(vars_to_eval_dict.keys())
    vars_to_eval = list(vars_to_eval_dict.values())

    with sess.as_default():
        initializer.initialize()
        step = fine_tune_from_iter + 1
        for batch in batch_generator:
            print('---')
            print('Iter: {}'.format(step))
            feed_dict = feed_dict_creator(batch)
            eval_vars = sess.run(
                vars_to_eval,
                feed_dict=feed_dict)
            eval_vars_dict = {
                var_name: eval_var for var_name, eval_var in
                zip(vars_to_eval_names, eval_vars)}
            logger.log(step, False, eval_vars_dict)
            step += 1

        # Force a final log (and model save) for the last completed iteration.
        logger.log(step - 1, True, eval_vars_dict)

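# A minimal sketch of the feed_dict_creator contract assumed by fine_tune()
# above: a closure that maps one raw batch to a feed dict keyed by the
# graph's placeholders. The placeholder arguments and batch keys here are
# hypothetical; the real creator comes from train.create_feed_dict_creator.
def _make_feed_dict_creator_sketch(image_plh, question_plh):
    def creator(batch):
        return {
            image_plh: batch['images'],
            question_plh: batch['questions'],
        }
    return creator
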

if __name__ == '__main__':
    print('Creating batch generator...')
    batch_generator = train.create_batch_generator()

    print('Creating computation graph...')
    graph = train.graph_creator(
        constants.tb_log_dir,
        constants.answer_batch_size,
        constants.image_size,
        constants.num_negative_answers,
        constants.answer_embedding_dim,
        constants.answer_regularization_coeff,
        constants.answer_batch_size*constants.num_region_proposals,
        constants.num_regions_with_labels,
        constants.num_object_labels,
        constants.num_attribute_labels,
        constants.answer_obj_atr_loss_wt,
        resnet_feat_dim=constants.resnet_feat_dim,
        training=True)

    print('Attaching optimizer...')
    optimizer = train.attach_optimizer(
        graph,
        constants.answer_lr)

    print('Starting a session...')
    # Let GPU memory grow as needed, capped at 80% of device memory.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    sess = tf.Session(config=config, graph=graph.tf_graph)

    print('Creating initializer...')
    # Restores the variables saved at constants.answer_fine_tune_from and
    # initializes the rest.
    initializer = create_initializer(
        graph,
        sess,
        constants.answer_fine_tune_from)

    print('Creating feed dict creator...')
    feed_dict_creator = train.create_feed_dict_creator(
        graph.plh,
        constants.num_negative_answers)

    print('Creating dict of vars to be evaluated...')
    # Running 'optimizer_op' applies one training step; the remaining
    # entries are tensors evaluated for logging.
    vars_to_eval_dict = {
        'optimizer_op': optimizer.train_op,
        'word_vectors': graph.word_vec_mgr.word_vectors,
        'relevance_prob': graph.relevance_inference.answer_region_prob[0],
        'per_region_answer_prob': graph.answer_inference.per_region_answer_prob[0],
        'object_scores': graph.obj_atr_inference.object_scores,
        'attribute_scores': graph.obj_atr_inference.attribute_scores,
        'answer_scores': graph.answer_inference.answer_score[0],
        'accuracy': graph.moving_average_accuracy,
        'total_loss': graph.total_loss,
        # 'question_embed_concat': graph.question_embed_concat,
        # 'answer_embed_concat': graph.answers_embed_concat,
        # 'noun_embed': graph.noun_embed['positive_nouns'],
        # 'adjective_embed': graph.adjective_embed['positive_adjectives'],
        # 'assert': graph.answer_inference.assert_op,
        'merged': graph.merged,
    }

    print('Creating logger...')
    logger = train.log_mgr(
        graph,
        graph.vars_to_save,
        sess,
        constants.answer_log_every_n_iter,
        constants.answer_output_dir,
        constants.answer_model)

    print('Start training...')
    fine_tune(
        constants.answer_fine_tune_from_iter,
        batch_generator,
        sess,
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        logger)