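"""Trains the object-attribute classifier on cropped region images.

Run as `python train.py`; all paths and hyperparameters come from the
constants module. Note that the script targets the pre-1.0 TensorFlow
API (tf.all_variables, tf.merge_all_summaries, tf.train.SummaryWriter)
and Python 2.
"""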
import os
import ujson
import numpy as np
import data.cropped_regions as cropped_regions
import tftools.data
from tftools import var_collect, placeholder_management
from object_attribute_classifier import inference
from word2vec.word_vector_management import word_vector_manager
import losses
import constants
import tensorflow as tf
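# Builds the full computation graph: input placeholders, the word-vector
# conditioned object/attribute inference network, losses, accuracy metrics,
# and TensorBoard summaries.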
class graph_creator():
    def __init__(self,
                 batch_size,
                 tb_log_dir,
                 image_size,
                 num_object_labels,
                 num_attribute_labels,
                 regularization_coeff,
                 training=True):
        self.im_h, self.im_w = image_size
        self.batch_size = batch_size
        self.num_object_labels = num_object_labels
        self.num_attribute_labels = num_attribute_labels
        self.regularization_coeff = regularization_coeff
        self.tf_graph = tf.Graph()
        with self.tf_graph.as_default():
            self.create_placeholders()
            self.word_vec_mgr = word_vector_manager()
            self.obj_atr_inference = inference.ObjectAttributeInference(
                self.plh['region_images'],
                self.word_vec_mgr.object_label_vectors,
                self.word_vec_mgr.attribute_label_vectors,
                training)
            self.add_losses()
            self.add_accuracy_computation()
            self.vars_to_save = tf.all_variables()
            self.merged = tf.merge_all_summaries()
            self.writer = tf.train.SummaryWriter(
                tb_log_dir,
                graph=self.tf_graph)
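    # Placeholder inputs: RGB region crops plus object and attribute label
    # vectors (objects are treated as single-label, attributes as
    # multi-label downstream).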
    def create_placeholders(self):
        self.plh = placeholder_management.PlaceholderManager()
        self.plh.add_placeholder(
            'region_images',
            tf.float32,
            shape=[None, self.im_h, self.im_w, 3])
        self.plh.add_placeholder(
            'object_labels',
            tf.float32,
            shape=[None, self.num_object_labels])
        self.plh.add_placeholder(
            'attribute_labels',
            tf.float32,
            shape=[None, self.num_attribute_labels])
    def add_losses(self):
        self.object_loss = losses.object_loss(
            self.obj_atr_inference.object_scores,
            self.plh['object_labels'])
        # The attribute loss is up-weighted by a factor of 100 relative
        # to the object loss.
        self.attribute_loss = 100 * losses.attribute_loss(
            self.obj_atr_inference.attribute_scores,
            self.plh['attribute_labels'],
            self.batch_size)
        self.regularization_loss = self.regularization()
        self.total_loss = (self.object_loss +
                           self.attribute_loss +
                           self.regularization_loss)
        tf.scalar_summary("loss_total", self.total_loss)
        tf.scalar_summary("loss_object", self.object_loss)
        tf.scalar_summary("loss_attribute", self.attribute_loss)
        tf.scalar_summary("loss_regularization", self.regularization_loss)
    def regularization(self):
        vars_to_regularize = tf.get_collection('to_regularize')
        loss = losses.regularization_loss(
            vars_to_regularize,
            self.regularization_coeff)
        return loss
    def add_accuracy_computation(self):
        with tf.variable_scope('accuracy_graph'):
            self.object_accuracy = self.add_object_accuracy_computation(
                self.obj_atr_inference.object_scores,
                self.plh['object_labels'])
            self.attribute_accuracy = self.add_attribute_accuracy_computation(
                self.obj_atr_inference.attribute_scores,
                self.plh['attribute_labels'])
            tf.scalar_summary("accuracy_object", self.object_accuracy)
            tf.scalar_summary("accuracy_attribute", self.attribute_accuracy)
    def add_object_accuracy_computation(self, scores, labels):
        with tf.variable_scope('object_accuracy'):
            correct_prediction = tf.equal(
                tf.argmax(scores, 1),
                tf.argmax(labels, 1),
                name='correct_prediction')
            object_accuracy = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32),
                name='accuracy')
        return object_accuracy
    def add_attribute_accuracy_computation(self, scores, labels):
        with tf.variable_scope('attribute_accuracy'):
            # Attributes are multi-label: a score above 0 counts as a
            # positive prediction for that attribute.
            thresholded = tf.greater(
                scores,
                0.0,
                name='thresholded')
            correct_prediction = tf.equal(
                thresholded,
                tf.cast(labels, tf.bool),
                name='correct_prediction')
            attribute_accuracy = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32),
                name='accuracy')
        return attribute_accuracy
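# Restores the ResNet weights from a checkpoint and returns an object whose
# initialize() randomly initializes every remaining variable.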
def create_initializer(graph, sess, resnet_model):
    class initializer():
        def __init__(self):
            with graph.tf_graph.as_default():
                resnet_vars = graph.obj_atr_inference.resnet_vars
                resnet_restorer = tf.train.Saver(resnet_vars)
                resnet_restorer.restore(sess, resnet_model)
                not_to_init = resnet_vars
                all_vars = tf.all_variables()
                other_vars = [var for var in all_vars
                              if var not in not_to_init]
                var_collect.print_var_list(
                    other_vars,
                    'vars_to_init')
                self.init = tf.initialize_variables(other_vars)

        def initialize(self):
            sess.run(self.init)

    return initializer()
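# Wires the cropped-region data manager to a random index generator and
# returns an asynchronous batch generator (batches are prefetched into a
# queue of size constants.region_queue_size).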
def create_batch_generator():
    data_mgr = cropped_regions.data(
        constants.image_dir,
        constants.object_labels_json,
        constants.attribute_labels_json,
        constants.regions_json,
        constants.image_size,
        channels=3,
        mean_image_filename=None)
    index_generator = tftools.data.random(
        constants.region_batch_size,
        constants.region_num_samples,
        constants.region_num_epochs,
        constants.region_offset)
    batch_generator = tftools.data.async_batch_generator(
        data_mgr,
        index_generator,
        constants.region_queue_size)
    return batch_generator
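# Maps a batch dict from the data manager to a feed dict keyed by the
# graph's placeholders.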
def create_feed_dict_creator(plh):
    def feed_dict_creator(batch):
        inputs = {
            'region_images': batch['region_images'],
            'object_labels': batch['object_labels'],
            'attribute_labels': batch['attribute_labels'],
        }
        return plh.get_feed_dict(inputs)

    return feed_dict_creator
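# Attaches Adam optimizers to the graph: one for the ResNet variables and
# one for everything else, grouped into a single train op.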
class attach_optimizer():
    def __init__(self, graph, lr):
        self.graph = graph
        self.lr = lr
        with graph.tf_graph.as_default():
            resnet_vars = graph.obj_atr_inference.resnet_vars
            all_trainable_vars = tf.trainable_variables()
            # Variables to freeze, e.g. resnet_vars +
            # [graph.word_vec_mgr.word_vectors] to keep the ResNet and
            # word vectors fixed. Currently everything is trained.
            self.not_to_train = []
            vars_to_train = [
                var for var in all_trainable_vars
                if var not in self.not_to_train]
            var_collect.print_var_list(
                vars_to_train,
                'vars_to_train')
            self.ops = dict()
            self.add_adam_optimizer(
                graph.total_loss,
                vars_to_train,
                'all_but_resnet')
            self.add_adam_optimizer(
                graph.total_loss,
                resnet_vars,
                'resnet')
            self.train_op = self.group_all_train_ops()
    def filter_out_vars_to_train(self, var_list):
        # Drops any variable listed in self.not_to_train.
        return [var for var in var_list if var not in self.not_to_train]

    def add_adam_optimizer(self, loss, var_list, name):
        var_list = self.filter_out_vars_to_train(var_list)
        if not var_list:
            self.ops[name] = []
            return
        train_step = tf.train.AdamOptimizer(self.lr).minimize(
            loss,
            var_list=var_list)
        self.ops[name] = train_step

    def group_all_train_ops(self):
        train_op = tf.group(*[op for op in self.ops.values() if op])
        # resnet_bn_updates = self.graph.tf_graph.get_collection(
        #     'resnet_update_ops')
        # resnet_bn_updates_op = tf.group(*resnet_bn_updates)
        # train_op = tf.group(train_op, resnet_bn_updates_op)
        return train_op
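# Pushes summaries and prints diagnostics on every call, and periodically
# saves model checkpoints and dumps the running loss history to JSON.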
class log_mgr():
    def __init__(
            self,
            graph,
            vars_to_save,
            sess,
            log_every_n_iter,
            output_dir,
            model_path):
        self.graph = graph
        self.vars_to_save = vars_to_save
        self.sess = sess
        self.log_every_n_iter = log_every_n_iter
        self.output_dir = output_dir
        self.model_path = model_path
        self.model_saver = tf.train.Saver(
            var_list=vars_to_save,
            max_to_keep=0)
        self.loss_values = dict()

    def log(self, iter, is_last=False, eval_vars_dict=None):
        if eval_vars_dict:
            self.graph.writer.add_summary(
                eval_vars_dict['merged'],
                iter)
            print 'object'
            print np.max(eval_vars_dict['object_prob'][0, :])
            print np.min(eval_vars_dict['object_prob'][0, :])
            print np.max(eval_vars_dict['object_scores'][0, :])
            print np.min(eval_vars_dict['object_scores'][0, :])
            print 'attribute'
            print np.max(eval_vars_dict['attribute_prob'][0, :])
            print np.min(eval_vars_dict['attribute_prob'][0, :])
            print np.max(eval_vars_dict['attribute_scores'][0, :])
            print np.min(eval_vars_dict['attribute_scores'][0, :])
            print 'object_accuracy'
            print eval_vars_dict['object_accuracy']
            print 'attribute_accuracy'
            print eval_vars_dict['attribute_accuracy']
            self.loss_values[iter] = {
                'total_loss': str(eval_vars_dict['total_loss']),
                'object_loss': str(eval_vars_dict['object_loss']),
                'attribute_loss': str(eval_vars_dict['attribute_loss'])}
        if iter % self.log_every_n_iter == 0 or is_last:
            self.model_saver.save(
                self.sess,
                self.model_path,
                global_step=iter)
            loss_path = os.path.join(
                self.output_dir,
                'losses_' + str(iter) + '.json')
            with open(loss_path, 'w') as outfile:
                ujson.dump(
                    self.loss_values,
                    outfile,
                    sort_keys=True,
                    indent=4)
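# Main training loop: evaluates the requested tensors (including the
# grouped train op) on each batch and hands the results to the logger.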
def train(
        batch_generator,
        sess,
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        logger):
    vars_to_eval_names = []
    vars_to_eval = []
    for var_name, var in vars_to_eval_dict.items():
        vars_to_eval_names += [var_name]
        vars_to_eval += [var]
    with sess.as_default():
        initializer.initialize()
        iter = 0
        for batch in batch_generator:
            print '---'
            print 'Iter: {}'.format(iter)
            feed_dict = feed_dict_creator(batch)
            eval_vars = sess.run(
                vars_to_eval,
                feed_dict=feed_dict)
            eval_vars_dict = {
                var_name: eval_var for var_name, eval_var in
                zip(vars_to_eval_names, eval_vars)}
            logger.log(iter, False, eval_vars_dict)
            iter += 1
        logger.log(iter - 1, True, eval_vars_dict)
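# Entry point: everything below is configured through the constants module.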
if __name__ == '__main__':
    print 'Creating batch generator...'
    batch_generator = create_batch_generator()

    print 'Creating computation graph...'
    graph = graph_creator(
        constants.region_batch_size,
        constants.tb_log_dir,
        constants.image_size,
        constants.num_object_labels,
        constants.num_attribute_labels,
        constants.region_regularization_coeff)

    print 'Attaching optimizer...'
    optimizer = attach_optimizer(
        graph,
        constants.region_lr)

    print 'Starting a session...'
    sess = tf.Session(graph=graph.tf_graph)

    print 'Creating initializer...'
    initializer = create_initializer(
        graph,
        sess,
        constants.resnet_ckpt)

    print 'Creating feed dict creator...'
    feed_dict_creator = create_feed_dict_creator(graph.plh)

    print 'Creating dict of vars to be evaluated...'
    vars_to_eval_dict = {
        'object_prob': graph.obj_atr_inference.object_prob,
        'object_scores': graph.obj_atr_inference.object_scores,
        'attribute_prob': graph.obj_atr_inference.attribute_prob,
        'attribute_scores': graph.obj_atr_inference.attribute_scores,
        'object_accuracy': graph.object_accuracy,
        'attribute_accuracy': graph.attribute_accuracy,
        'total_loss': graph.total_loss,
        'object_loss': graph.object_loss,
        'attribute_loss': graph.attribute_loss,
        'optimizer_op': optimizer.train_op,
        'merged': graph.merged,
    }

    print 'Creating logger...'
    vars_to_save = graph.vars_to_save
    logger = log_mgr(
        graph,
        vars_to_save,
        sess,
        constants.region_log_every_n_iter,
        constants.region_output_dir,
        constants.region_model)

    print 'Start training...'
    train(
        batch_generator,
        sess,
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        logger)