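# NOTE: the next two lines are web-page residue picked up when this file was
# copied, not program text; they are kept as comments so the module parses.
# Something went wrong on our end
# fine_tune.py 5.07 KiB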
# from word2vec.word_vector_management import word_vector_manager
# import object_attribute_classifier_cached_features.inference as feature_graph
# import region_relevance_network.inference as relevance_graph
# import answer_classifier_cached_features.inference as answer_graph
from tftools import var_collect, placeholder_management
import tftools.data
import losses
import constants
import tftools.var_collect as var_collect
import data.vqa
import answer_classifier_cached_features.train as train
import numpy as np
import pdb
import tensorflow as tf
def create_initializer(graph, sess, model):
    """Restore saved variables from `model` and prepare an init op for the rest.

    Args:
        graph: Object exposing `tf_graph` (the TensorFlow graph to work in)
            and `vars_to_save` (the variables handed to `tf.train.Saver`).
        sess: Session used both for the restore and for running the init op.
        model: Checkpoint path passed to `Saver.restore`.

    Returns:
        An object with an `initialize()` method. The restore has already
        happened by the time it is returned; `initialize()` only runs the
        init op for the variables that were not restored.
    """

    class initializer():
        """Holds the init op for every variable outside `graph.vars_to_save`."""

        def __init__(self):
            with graph.tf_graph.as_default():
                restored_vars = graph.vars_to_save

                # Restoring happens here, at construction time, not in
                # initialize().
                tf.train.Saver(restored_vars).restore(sess, model)

                # Everything the checkpoint did not cover still needs a
                # fresh initial value.
                fresh_vars = []
                for candidate in tf.all_variables():
                    if candidate in restored_vars:
                        continue
                    fresh_vars.append(candidate)

                var_collect.print_var_list(fresh_vars, 'vars_to_init')
                self.init = tf.initialize_variables(fresh_vars)

        def initialize(self):
            """Run the init op for the non-restored variables."""
            sess.run(self.init)

    return initializer()
def fine_tune(
        fine_tune_from_iter,
        batch_generator,
        sess,
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        logger):
    """Run the fine-tuning loop, logging the evaluated values each iteration.

    Args:
        fine_tune_from_iter: Iteration number training resumes from; the
            first batch is reported as `fine_tune_from_iter + 1`.
        batch_generator: Iterable of batches. Each batch is passed to
            `feed_dict_creator`.
        sess: Session-like object providing `as_default()` and
            `run(fetches, feed_dict=...)`.
        initializer: Object whose `initialize()` is called once, before the
            first batch.
        vars_to_eval_dict: Mapping from a name to the fetch evaluated on
            every iteration.
        feed_dict_creator: Callable that turns a batch into a feed dict.
        logger: Object with `log(iteration, flag, values)`. It is called with
            `flag=False` after every batch and once more with `flag=True`
            after the last batch, using that batch's values.

    If `batch_generator` yields nothing, no log call is made at all (the
    previous code raised NameError on the final log in that case).
    """
    # keys() and values() of an unmodified dict iterate in the same order, so
    # the two lists stay aligned with what sess.run returns.
    vars_to_eval_names = list(vars_to_eval_dict.keys())
    vars_to_eval = list(vars_to_eval_dict.values())

    with sess.as_default():
        initializer.initialize()

        iteration = fine_tune_from_iter + 1
        last_eval_vars_dict = None
        for batch in batch_generator:
            # Single-argument parenthesized print behaves the same on
            # Python 2 and Python 3.
            print('---')
            print('Iter: {}'.format(iteration))

            feed_dict = feed_dict_creator(batch)
            eval_vars = sess.run(vars_to_eval, feed_dict=feed_dict)
            last_eval_vars_dict = dict(zip(vars_to_eval_names, eval_vars))

            logger.log(iteration, False, last_eval_vars_dict)
            iteration += 1

        # Final log call for the last processed iteration; skipped when no
        # batch was processed because there is nothing to report.
        if last_eval_vars_dict is not None:
            logger.log(iteration - 1, True, last_eval_vars_dict)
# Script entry point: build the training graph, restore the checkpoint named
# by constants.answer_fine_tune_from, and continue training from
# constants.answer_fine_tune_from_iter.
if __name__ == '__main__':
    # Single-argument parenthesized print behaves the same on Python 2 and
    # Python 3.
    print('Creating batch generator...')
    batch_generator = train.create_batch_generator()

    print('Creating computation graph...')
    graph = train.graph_creator(
        constants.tb_log_dir,
        constants.answer_batch_size,
        constants.image_size,
        constants.num_negative_answers,
        constants.answer_embedding_dim,
        constants.answer_regularization_coeff,
        constants.answer_batch_size*constants.num_region_proposals,
        constants.num_regions_with_labels,
        constants.num_object_labels,
        constants.num_attribute_labels,
        constants.answer_obj_atr_loss_wt,
        resnet_feat_dim=constants.resnet_feat_dim,
        training=True)

    print('Attaching optimizer...')
    optimizer = train.attach_optimizer(
        graph,
        constants.answer_lr)

    print('Starting a session...')
    config = tf.ConfigProto()
    # Let the GPU allocation grow on demand, capped at 80% of device memory.
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    sess = tf.Session(config=config, graph=graph.tf_graph)

    print('Creating initializer...')
    initializer = create_initializer(
        graph,
        sess,
        constants.answer_fine_tune_from)

    print('Creating feed dict creator...')
    feed_dict_creator = train.create_feed_dict_creator(
        graph.plh,
        constants.num_negative_answers)

    print('Creating dict of vars to be evaluated...')
    # Every entry is fetched on each iteration; 'optimizer_op' is the train
    # op, so including it is what performs the update step.
    vars_to_eval_dict = {
        'optimizer_op': optimizer.train_op,
        'word_vectors': graph.word_vec_mgr.word_vectors,
        'relevance_prob': graph.relevance_inference.answer_region_prob[0],
        'per_region_answer_prob': graph.answer_inference.per_region_answer_prob[0],
        'object_scores': graph.obj_atr_inference.object_scores,
        'attribute_scores': graph.obj_atr_inference.attribute_scores,
        'answer_scores': graph.answer_inference.answer_score[0],
        'accuracy': graph.moving_average_accuracy,
        'total_loss': graph.total_loss,
        # Fetches below were already disabled in the original script.
        # 'question_embed_concat': graph.question_embed_concat,
        # 'answer_embed_concat': graph.answers_embed_concat,
        # 'noun_embed': graph.noun_embed['positive_nouns'],
        # 'adjective_embed': graph.adjective_embed['positive_adjectives'],
        # 'assert': graph.answer_inference.assert_op,
        'merged': graph.merged,
    }

    print('Creating logger...')
    # Previously assigned but unused; now it is the value handed to the logger.
    vars_to_save = graph.vars_to_save
    logger = train.log_mgr(
        graph,
        vars_to_save,
        sess,
        constants.answer_log_every_n_iter,
        constants.answer_output_dir,
        constants.answer_model)

    print('Start training...')
    fine_tune(
        constants.answer_fine_tune_from_iter,
        batch_generator,
        sess,
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        logger)