# from word2vec.word_vector_management import word_vector_manager
# import object_attribute_classifier_cached_features.inference as feature_graph
# import region_relevance_network.inference as relevance_graph
# import answer_classifier_cached_features.inference as answer_graph
from tftools import var_collect, placeholder_management
import tftools.data
import losses
import constants
import tftools.var_collect as var_collect
import data.vqa
import answer_classifier_cached_features.train as train
import numpy as np
import pdb
import tensorflow as tf


def create_initializer(graph, sess, model):
    """Build an initializer that restores checkpointed variables and
    freshly initializes everything else in the graph.

    Args:
        graph: object exposing ``tf_graph`` (a tf.Graph) and
            ``vars_to_save`` (the variables stored in the checkpoint).
        sess: tf.Session bound to ``graph.tf_graph``.
        model: path to the checkpoint to restore from.

    Returns:
        An object with an ``initialize()`` method; calling it runs the
        init op for all variables that were NOT restored (e.g. optimizer
        slot variables created after the checkpoint was written).
    """

    class initializer():
        def __init__(self):
            with graph.tf_graph.as_default():
                # Restore only the variables present in the checkpoint.
                model_vars = graph.vars_to_save
                model_restorer = tf.train.Saver(model_vars)
                model_restorer.restore(sess, model)
                # Use a set for O(1) membership tests instead of the
                # original O(n) list scan per variable.
                not_to_init = set(model_vars)
                all_vars = tf.all_variables()
                other_vars = [
                    var for var in all_vars if var not in not_to_init]
                var_collect.print_var_list(other_vars, 'vars_to_init')
                self.init = tf.initialize_variables(other_vars)

        def initialize(self):
            sess.run(self.init)

    return initializer()


def fine_tune(
        fine_tune_from_iter,
        batch_generator,
        sess,
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        logger):
    """Run the fine-tuning loop.

    Args:
        fine_tune_from_iter: iteration number of the restored checkpoint;
            training resumes at ``fine_tune_from_iter + 1``.
        batch_generator: iterable yielding training batches.
        sess: tf.Session to run the graph in.
        initializer: object returned by ``create_initializer``.
        vars_to_eval_dict: dict mapping names to graph tensors/ops to
            evaluate each step (includes the optimizer op).
        feed_dict_creator: callable mapping a batch to a feed_dict.
        logger: object with ``log(iter, is_final, eval_vars_dict)``.
    """
    vars_to_eval_names = []
    vars_to_eval = []
    for var_name, var in vars_to_eval_dict.items():
        vars_to_eval_names += [var_name]
        vars_to_eval += [var]

    with sess.as_default():
        initializer.initialize()
        # Renamed from 'iter' to avoid shadowing the builtin.
        step = fine_tune_from_iter + 1
        eval_vars_dict = None  # stays None if the generator is empty
        for batch in batch_generator:
            print('---')
            print('Iter: {}'.format(step))
            feed_dict = feed_dict_creator(batch)
            eval_vars = sess.run(vars_to_eval, feed_dict=feed_dict)
            eval_vars_dict = {
                var_name: eval_var for var_name, eval_var
                in zip(vars_to_eval_names, eval_vars)}
            logger.log(step, False, eval_vars_dict)
            step += 1

        # Bug fix: the original unconditionally logged here, raising a
        # NameError on eval_vars_dict when the generator yielded no
        # batches. Only write the final log if at least one step ran.
        if eval_vars_dict is not None:
            logger.log(step - 1, True, eval_vars_dict)


if __name__ == '__main__':
    print('Creating batch generator...')
    batch_generator = train.create_batch_generator()

    print('Creating computation graph...')
graph = train.graph_creator( constants.tb_log_dir, constants.answer_batch_size, constants.image_size, constants.num_negative_answers, constants.answer_embedding_dim, constants.answer_regularization_coeff, constants.answer_batch_size*constants.num_region_proposals, constants.num_regions_with_labels, constants.num_object_labels, constants.num_attribute_labels, constants.answer_obj_atr_loss_wt, resnet_feat_dim=constants.resnet_feat_dim, training=True) print 'Attaching optimizer...' optimizer = train.attach_optimizer( graph, constants.answer_lr) print 'Starting a session...' config = tf.ConfigProto() config.gpu_options.allow_growth = True config.gpu_options.per_process_gpu_memory_fraction = 0.8 sess = tf.Session(config=config, graph=graph.tf_graph) print 'Creating initializer...' initializer = create_initializer( graph, sess, constants.answer_fine_tune_from) print 'Creating feed dict creator...' feed_dict_creator = train.create_feed_dict_creator( graph.plh, constants.num_negative_answers) print 'Creating dict of vars to be evaluated...' vars_to_eval_dict = { 'optimizer_op': optimizer.train_op, 'word_vectors': graph.word_vec_mgr.word_vectors, 'relevance_prob': graph.relevance_inference.answer_region_prob[0], 'per_region_answer_prob': graph.answer_inference.per_region_answer_prob[0], 'object_scores': graph.obj_atr_inference.object_scores, 'attribute_scores': graph.obj_atr_inference.attribute_scores, 'answer_scores': graph.answer_inference.answer_score[0], 'accuracy': graph.moving_average_accuracy, 'total_loss': graph.total_loss, # 'question_embed_concat': graph.question_embed_concat, # 'answer_embed_concat': graph.answers_embed_concat, # 'noun_embed': graph.noun_embed['positive_nouns'], # 'adjective_embed': graph.adjective_embed['positive_adjectives'], # 'assert': graph.answer_inference.assert_op, 'merged': graph.merged, } print 'Creating logger...' 
vars_to_save = graph.vars_to_save logger = train.log_mgr( graph, graph.vars_to_save, sess, constants.answer_log_every_n_iter, constants.answer_output_dir, constants.answer_model) print 'Start training...' fine_tune( constants.answer_fine_tune_from_iter, batch_generator, sess, initializer, vars_to_eval_dict, feed_dict_creator, logger)