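"""Training script for the object-attribute classifier on cropped regions.

Builds the TensorFlow graph (word vectors, inference, losses, accuracies,
summaries), attaches Adam optimizers, restores pretrained ResNet weights,
and runs the training loop with periodic checkpointing and loss logging.

Written against the pre-1.0 TensorFlow API (tf.scalar_summary,
tf.merge_all_summaries, tf.train.SummaryWriter) and Python 2.
"""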
import os
import ujson
import numpy as np

import data.cropped_regions as cropped_regions
import tftools.data
from tftools import var_collect, placeholder_management
from object_attribute_classifier import inference
from word2vec.word_vector_management import word_vector_manager
import losses
import constants

import tensorflow as tf


class graph_creator():
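    """Builds the training graph: placeholders, word vectors, inference,
    losses, accuracy computations and Tensorboard summaries.
    """
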
    def __init__(self, 
                 batch_size, 
                 tb_log_dir,
                 image_size,
                 num_object_labels,
                 num_attribute_labels,
                 regularization_coeff,
                 training=True):
        self.im_h, self.im_w = image_size
        self.batch_size = batch_size
        self.num_object_labels = num_object_labels
        self.num_attribute_labels = num_attribute_labels
        self.regularization_coeff = regularization_coeff
        self.tf_graph = tf.Graph()
        with self.tf_graph.as_default():
            self.create_placeholders()
            self.word_vec_mgr = word_vector_manager()
            self.obj_atr_inference = inference.ObjectAttributeInference(
                self.plh['region_images'],
                self.word_vec_mgr.object_label_vectors,
                self.word_vec_mgr.attribute_label_vectors,
                training)
            self.add_losses()
            self.add_accuracy_computation()
            self.vars_to_save = tf.all_variables()
            self.merged = tf.merge_all_summaries()
            self.writer = tf.train.SummaryWriter(
                tb_log_dir,
                graph = self.tf_graph)

    def create_placeholders(self):
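        """Registers placeholders for batches of region images and their
        object/attribute label vectors.
        """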
        self.plh = placeholder_management.PlaceholderManager()
        
        self.plh.add_placeholder(
            'region_images',
            tf.float32,
            shape=[None, self.im_h, self.im_w, 3])

        self.plh.add_placeholder(
            'object_labels',
            tf.float32,
            shape=[None, self.num_object_labels])

        self.plh.add_placeholder(
            'attribute_labels',
            tf.float32,
            shape=[None, self.num_attribute_labels])        

    def add_losses(self):
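        """Builds the object, attribute and regularization losses, their
        sum, and a scalar summary for each.
        """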
        self.object_loss = losses.object_loss(
            self.obj_atr_inference.object_scores, 
            self.plh['object_labels'])

        # The attribute loss is up-weighted by a factor of 100 relative to
        # the object loss.
        self.attribute_loss = 100 * losses.attribute_loss(
            self.obj_atr_inference.attribute_scores,
            self.plh['attribute_labels'],
            self.batch_size)

        self.regularization_loss = self.regularization()
        
        self.total_loss = self.object_loss + \
                          self.attribute_loss + \
                          self.regularization_loss

        # The summary ops register themselves in the default summaries
        # collection and are gathered by tf.merge_all_summaries(), so the
        # returned tensors need not be kept.
        tf.scalar_summary("loss_total", self.total_loss)
        tf.scalar_summary("loss_object", self.object_loss)
        tf.scalar_summary("loss_attribute", self.attribute_loss)
        tf.scalar_summary("loss_regularization", self.regularization_loss)
                
    def regularization(self):
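        """Returns the regularization loss over all variables registered in
        the 'to_regularize' graph collection.
        """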
        vars_to_regularize = tf.get_collection('to_regularize')
        loss = losses.regularization_loss(
            vars_to_regularize,
            self.regularization_coeff)
        return loss

    def add_accuracy_computation(self):
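        """Adds object and attribute accuracy ops and their scalar
        summaries.
        """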
        with tf.variable_scope('accuracy_graph'):
            self.object_accuracy = self.add_object_accuracy_computation(
                self.obj_atr_inference.object_scores,
                self.plh['object_labels'])
        
            self.attribute_accuracy = self.add_attribute_accuracy_computation(
                self.obj_atr_inference.attribute_scores,
                self.plh['attribute_labels'])

            tf.scalar_summary("accuracy_object", self.object_accuracy)
            tf.scalar_summary("accuracy_attribute", self.attribute_accuracy)
            
    def add_object_accuracy_computation(self, scores, labels):
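        """Fraction of regions whose argmax object score matches the argmax
        of the label vector.
        """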
        with tf.variable_scope('object_accuracy'):
            correct_prediction = tf.equal(
                tf.argmax(scores, 1), 
                tf.argmax(labels, 1), 
                name='correct_prediction')

            object_accuracy = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32), 
                name='accuracy')

        return object_accuracy

    def add_attribute_accuracy_computation(self, scores, labels):
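        """Mean element-wise accuracy of thresholded attribute scores.

        Scores above 0 count as positive predictions, which corresponds to
        probability > 0.5 if the scores are sigmoid logits. Worked example
        with hypothetical values: scores=[[0.3, -1.2]] and labels=[[1., 0.]]
        threshold to [[True, False]], match the labels exactly, and give
        accuracy 1.0.
        """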
        with tf.variable_scope('attribute_accuracy'):
            # Scores above 0 count as positive attribute predictions.
            thresholded = tf.greater(
                scores,
                0.0,
                name='thresholded')

            correct_prediction = tf.equal(
                thresholded,
                tf.cast(labels, tf.bool),
                name='correct_prediction')

            attribute_accuracy = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32), 
                name='accuracy')

        return attribute_accuracy


def create_initializer(graph, sess, resnet_model):
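    """Returns an initializer that restores pretrained ResNet weights from
    resnet_model into sess and initializes every remaining variable.
    """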
    class initializer():
        def __init__(self):
            with graph.tf_graph.as_default():
                resnet_vars = graph.obj_atr_inference.resnet_vars
                resnet_restorer = tf.train.Saver(resnet_vars)
                resnet_restorer.restore(sess, resnet_model)    
                all_vars = tf.all_variables()
                other_vars = [var for var in all_vars
                              if var not in resnet_vars]
                var_collect.print_var_list(
                    other_vars,
                    'vars_to_init')
                self.init = tf.initialize_variables(other_vars)

        def initialize(self):
            sess.run(self.init)
    
    return initializer()


def create_batch_generator():
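    """Returns an asynchronous generator over random batches of cropped
    region images with object and attribute labels.
    """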
    data_mgr = cropped_regions.data(
        constants.image_dir,
        constants.object_labels_json,
        constants.attribute_labels_json,
        constants.regions_json,
        constants.image_size,
        channels=3,
        mean_image_filename=None)

    index_generator = tftools.data.random(
        constants.region_batch_size, 
        constants.region_num_samples, 
        constants.region_num_epochs, 
        constants.region_offset)
    
    batch_generator = tftools.data.async_batch_generator(
        data_mgr, 
        index_generator, 
        constants.region_queue_size)
    
    return batch_generator


def create_feed_dict_creator(plh):
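    """Returns a function that maps a data batch to a feed dict for the
    managed placeholders.
    """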
    def feed_dict_creator(batch):
        inputs = {
            'region_images': batch['region_images'],
            'object_labels': batch['object_labels'],
            'attribute_labels': batch['attribute_labels']
        }
        return plh.get_feed_dict(inputs)

    return feed_dict_creator


class attach_optimizer():
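    """Attaches Adam optimizers for the ResNet and non-ResNet variables and
    exposes a single grouped train_op.
    """
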
    def __init__(self, graph, lr):
        self.graph = graph
        self.lr = lr
        with graph.tf_graph.as_default():
            resnet_vars = graph.obj_atr_inference.resnet_vars
            all_trainable_vars = tf.trainable_variables()
            # Variables listed here are excluded from training, e.g.
            # resnet_vars + [graph.word_vec_mgr.word_vectors] to freeze them.
            self.not_to_train = []
            vars_to_train = [
                var for var in all_trainable_vars
                if var not in self.not_to_train]
            var_collect.print_var_list(
                vars_to_train,
                'vars_to_train')
            self.ops = dict()

            # Note: with not_to_train empty, 'all_but_resnet' covers every
            # trainable variable, so the resnet variables are updated by both
            # optimizers below.
            self.add_adam_optimizer(
                graph.total_loss,
                vars_to_train,
                'all_but_resnet')

            self.add_adam_optimizer(
                graph.total_loss,
                resnet_vars,
                'resnet')

            self.train_op = self.group_all_train_ops()

    def filter_out_vars_to_train(self, var_list):
        return [var for var in var_list if var not in self.not_to_train]

    def add_adam_optimizer(self, loss, var_list, name):
        var_list = self.filter_out_vars_to_train(var_list)
        if not var_list:
            self.ops[name] = []
            return

        train_step = tf.train.AdamOptimizer(self.lr) \
                             .minimize(
                                 loss,
                                 var_list=var_list)
        
        self.ops[name] = train_step

    def group_all_train_ops(self):
        ops = [op for op in self.ops.values() if op]
        train_op = tf.group(*ops)

        # ResNet batch-norm moving-average updates are currently not
        # applied. To include them, uncomment:
        # resnet_bn_updates = self.graph.tf_graph.get_collection(
        #     'resnet_update_ops')
        # train_op = tf.group(train_op, *resnet_bn_updates)
        return train_op

class log_mgr():
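    """Writes Tensorboard summaries and console diagnostics on every call
    to log(), and saves a model checkpoint plus a JSON dump of the loss
    history every log_every_n_iter iterations (and on the last iteration).
    """
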
    def __init__(
            self, 
            graph,
            vars_to_save, 
            sess, 
            log_every_n_iter,
            output_dir,
            model_path):
        self.graph = graph
        self.vars_to_save = vars_to_save
        self.sess = sess
        self.log_every_n_iter = log_every_n_iter
        self.output_dir = output_dir
        self.model_path = model_path

        self.model_saver = tf.train.Saver(
            var_list=vars_to_save,
            max_to_keep=0)

        self.loss_values = dict()

    def log(self, iter, is_last=False, eval_vars_dict=None):
        if eval_vars_dict:
            self.graph.writer.add_summary(
                eval_vars_dict['merged'], 
                iter)
            print 'object'
            print 'prob max/min:', \
                np.max(eval_vars_dict['object_prob'][0, :]), \
                np.min(eval_vars_dict['object_prob'][0, :])
            print 'scores max/min:', \
                np.max(eval_vars_dict['object_scores'][0, :]), \
                np.min(eval_vars_dict['object_scores'][0, :])

            print 'attribute'
            print 'prob max/min:', \
                np.max(eval_vars_dict['attribute_prob'][0, :]), \
                np.min(eval_vars_dict['attribute_prob'][0, :])
            print 'scores max/min:', \
                np.max(eval_vars_dict['attribute_scores'][0, :]), \
                np.min(eval_vars_dict['attribute_scores'][0, :])

            print 'object_accuracy'
            print eval_vars_dict['object_accuracy']

            print 'attribute_accuracy'
            print eval_vars_dict['attribute_accuracy']

            self.loss_values[iter] = {
                'total_loss': str(eval_vars_dict['total_loss']),
                'object_loss': str(eval_vars_dict['object_loss']),
                'attribute_loss': str(eval_vars_dict['attribute_loss'])}

        if iter % self.log_every_n_iter == 0 or is_last:
            self.model_saver.save(
                self.sess, 
                self.model_path, 
                global_step=iter)
            
            loss_path = os.path.join(
                self.output_dir,
                'losses_' + str(iter) + '.json')

            with open(loss_path, 'w') as outfile:
                ujson.dump(
                    self.loss_values, 
                    outfile, 
                    sort_keys=True, 
                    indent=4)
                    
        
def train(
        batch_generator, 
        sess, 
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        logger):
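    """Runs the training loop: initializes variables, then evaluates
    vars_to_eval_dict (which includes the optimizer op) on every batch and
    logs the results; the final iteration is logged again with is_last=True.
    """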

    vars_to_eval_names = []
    vars_to_eval = []
    for var_name, var in vars_to_eval_dict.items():
        vars_to_eval_names += [var_name]
        vars_to_eval += [var]

    with sess.as_default():
        initializer.initialize()

        iter = 0
        for batch in batch_generator:
            print '---'
            print 'Iter: {}'.format(iter)
            feed_dict = feed_dict_creator(batch)
            eval_vars = sess.run(
                vars_to_eval,
                feed_dict=feed_dict)
            eval_vars_dict = {
                var_name: eval_var for var_name, eval_var in
                zip(vars_to_eval_names, eval_vars)}
            logger.log(iter, False, eval_vars_dict)
            iter += 1
        
        logger.log(iter-1, True, eval_vars_dict)


if __name__ == '__main__':
    print 'Creating batch generator...'
    batch_generator = create_batch_generator()

    print 'Creating computation graph...'
    graph = graph_creator(
        constants.region_batch_size,
        constants.tb_log_dir,
        constants.image_size,
        constants.num_object_labels,
        constants.num_attribute_labels,
        constants.region_regularization_coeff)

    print 'Attaching optimizer...'
    optimizer = attach_optimizer(
        graph, 
        constants.region_lr)

    print 'Starting a session...'
    sess = tf.Session(graph=graph.tf_graph)

    print 'Creating initializer...'
    initializer = create_initializer(
        graph, 
        sess, 
        constants.resnet_ckpt)

    print 'Creating feed dict creator...'
    feed_dict_creator = create_feed_dict_creator(graph.plh)

    print 'Creating dict of vars to be evaluated...'
    vars_to_eval_dict = {
        'object_prob': graph.obj_atr_inference.object_prob,
        'object_scores': graph.obj_atr_inference.object_scores,
        'attribute_prob': graph.obj_atr_inference.attribute_prob,
        'attribute_scores': graph.obj_atr_inference.attribute_scores,
        'object_accuracy': graph.object_accuracy,
        'attribute_accuracy': graph.attribute_accuracy,
        'total_loss': graph.total_loss,
        'object_loss': graph.object_loss,
        'attribute_loss': graph.attribute_loss,        
        'optimizer_op': optimizer.train_op,
        'merged': graph.merged,
    }

    print 'Creating logger...'
    vars_to_save = graph.vars_to_save
    logger = log_mgr(
        graph,
        vars_to_save, 
        sess, 
        constants.region_log_every_n_iter,
        constants.region_output_dir,
        constants.region_model)

    print 'Start training...'
    train(
        batch_generator, 
        sess, 
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        logger)