# Skip to content
# Snippets Groups Projects
# inference.py 8.46 KiB
import pdb

import resnet.inference as resnet_inference
from tftools import var_collect, placeholder_management, layers
import constants
from word2vec.word_vector_management import word_vector_manager
import losses 

import tensorflow as tf


class ObjectAttributeInference():
    """Graph that scores image regions against object and attribute labels.

    Region features come from ``resnet_inference.inference`` followed by a
    batch-norm layer. Two small fully connected stacks map those features to
    an object embedding and an attribute embedding; two more stacks map the
    label vectors to embeddings of matching width. Scores are the cosine
    similarity between region and label embeddings, passed through a learned
    scale and bias.

    Attributes set by ``__init__``:
        image_regions, training: the constructor arguments, stored as given.
        avg_pool_feat: batch-normalised output of ``resnet_inference.inference``.
        resnet_vars: variables collected from scopes ``scale1`` .. ``scale5``.
        object_embed, attribute_embed: region embeddings.
        object_label_embed, attribute_label_embed: label embeddings.
        object_scores, attribute_scores: scaled-and-shifted cosine scores.
        object_scores_alpha, object_scores_bias: scalar scale and bias.
        attribute_scores_alpha, attribute_scores_bias: scale and bias of
            shape ``[1, <second dimension of attribute_scores>]``.
        object_prob: softmax of ``object_scores``.
        attribute_prob: sigmoid of ``attribute_scores``.
    """

    def __init__(
            self,
            image_regions,
            object_label_vectors,
            attribute_label_vectors,
            training):
        """Build the whole inference graph.

        Args:
            image_regions: input handed to ``resnet_inference.inference``.
            object_label_vectors: input to the object label embedding stack
                (presumably one word vector per object label -- confirm
                against ``word_vector_manager``).
            attribute_label_vectors: input to the attribute label embedding
                stack (same assumption as above).
            training: value forwarded to ``resnet_inference.inference`` and
                wrapped in ``tf.constant`` for every batch-norm layer.
        """
        self.image_regions = image_regions
        self.training = training

        self.avg_pool_feat = resnet_inference.inference(
            self.image_regions,
            self.training,
            num_classes=None)
        self.avg_pool_feat = layers.batch_norm(
            self.avg_pool_feat,
            tf.constant(self.training))

        self.resnet_vars = self.get_resnet_vars()

        # Region embeddings are built first: the label graphs read their
        # width from self.object_embed / self.attribute_embed.
        self.object_embed = self.add_object_graph(self.avg_pool_feat)
        self.attribute_embed = self.add_attribute_graph(self.avg_pool_feat)

        self.object_label_embed = self.add_object_label_graph(
            object_label_vectors)
        self.attribute_label_embed = self.add_attribute_label_graph(
            attribute_label_vectors)

        with tf.variable_scope('object_score_graph'):
            self.object_scores = self.compute_cosine_similarity(
                self.object_embed,
                self.object_label_embed)

            self.object_scores_alpha = tf.get_variable(
                'object_alpha',
                shape=[],
                initializer=tf.constant_initializer(1.0))

            # Only the object scale is registered for regularization here.
            tf.add_to_collection(
                'to_regularize',
                self.object_scores_alpha)

            self.object_scores_bias = tf.get_variable(
                'object_beta',
                shape=[],
                initializer=tf.constant_initializer(0.0))

            self.object_scores = (
                self.object_scores_alpha * self.object_scores
                + self.object_scores_bias)

            # NOTE: an alternative that scored objects with a plain fully
            # connected layer over relu(object_embed) was previously kept
            # here as commented-out code.

            self.object_prob = tf.nn.softmax(
                self.object_scores,
                name='object_prob')

        with tf.variable_scope('attribute_score_graph'):
            self.attribute_scores = self.compute_cosine_similarity(
                self.attribute_embed,
                self.attribute_label_embed)

            # Scale and bias have one entry per column of the score matrix.
            num_score_columns = \
                self.attribute_scores.get_shape().as_list()[1]

            self.attribute_scores_alpha = tf.get_variable(
                'attribute_alpha',
                shape=[1, num_score_columns],
                initializer=tf.constant_initializer(1.0))

            self.attribute_scores_bias = tf.get_variable(
                'attribute_beta',
                shape=[1, num_score_columns],
                initializer=tf.constant_initializer(0.0))

            self.attribute_scores = (
                self.attribute_scores_alpha * self.attribute_scores
                + self.attribute_scores_bias)

            self.attribute_prob = tf.sigmoid(
                self.attribute_scores,
                name='attribute_prob')

    def get_resnet_vars(self):
        """Return the variables found in scopes ``scale1`` .. ``scale5``.

        Uses ``var_collect.collect_scope`` for each scope and concatenates
        the results in scope order.
        """
        vars_resnet = []
        # range() instead of xrange() so this also runs on Python 3.
        for scale_idx in range(1, 6):
            vars_resnet += var_collect.collect_scope('scale' + str(scale_idx))

        return vars_resnet

    def _add_two_layer_graph(self, scope_name, features, fc1_dim,
                             fc2_dim=None):
        """Build the fc1 -> batch_norm -> relu -> fc2 stack shared by all graphs.

        Args:
            scope_name: outer variable scope; the layers live in
                ``<scope_name>/fc1`` and ``<scope_name>/fc2``.
            features: input to the first fully connected layer.
            fc1_dim: output width requested from the first layer.
            fc2_dim: output width requested from the second layer. When
                ``None``, half (integer division) of the last dimension of
                the first layer's output is used.

        Returns:
            Output of the second fully connected layer (no activation).
        """
        with tf.variable_scope(scope_name):
            with tf.variable_scope('fc1'):
                hidden = layers.full(
                    features,
                    fc1_dim,
                    'fc',
                    func=None)
                hidden = layers.batch_norm(
                    hidden,
                    tf.constant(self.training))
                hidden = tf.nn.relu(hidden)

            with tf.variable_scope('fc2'):
                if fc2_dim is None:
                    # '//' keeps the width an int on Python 2 and Python 3.
                    fc2_dim = hidden.get_shape().as_list()[-1] // 2
                embedding = layers.full(
                    hidden,
                    fc2_dim,
                    'fc',
                    func=None)

        return embedding

    def add_object_graph(self, input):
        """Embed region features for object scoring.

        Each of the two layers halves the last dimension of its input.
        Variables are created under the ``object_graph`` scope.
        """
        # '//' rather than '/': a layer width must be an int on Python 3 too.
        fc1_dim = input.get_shape().as_list()[-1] // 2
        return self._add_two_layer_graph('object_graph', input, fc1_dim)

    def add_attribute_graph(self, input):
        """Embed region features for attribute scoring.

        Each of the two layers halves the last dimension of its input.
        Variables are created under the ``attribute_graph`` scope.
        """
        fc1_dim = input.get_shape().as_list()[-1] // 2
        return self._add_two_layer_graph('attribute_graph', input, fc1_dim)

    def add_object_label_graph(self, input):
        """Embed object label vectors to the width of ``self.object_embed``.

        Both layers are requested with that same width. Variables are
        created under the ``object_label_graph`` scope. Requires
        ``self.object_embed`` to exist already.
        """
        embed_dim = self.object_embed.get_shape().as_list()[-1]
        return self._add_two_layer_graph(
            'object_label_graph', input, embed_dim, embed_dim)

    def add_attribute_label_graph(self, input):
        """Embed attribute label vectors to the width of ``self.attribute_embed``.

        Both layers are requested with that same width. Variables are
        created under the ``attribute_label_graph`` scope. Requires
        ``self.attribute_embed`` to exist already.
        """
        embed_dim = self.attribute_embed.get_shape().as_list()[-1]
        return self._add_two_layer_graph(
            'attribute_label_graph', input, embed_dim, embed_dim)

    def compute_cosine_similarity(self, feat1, feat2):
        """Return ``normalize(feat1) x normalize(feat2)^T``.

        Both inputs are L2-normalised along dimension 1 before the matrix
        product, so entry ``[i, j]`` is the cosine similarity between row
        ``i`` of ``feat1`` and row ``j`` of ``feat2``.
        """
        feat1 = tf.nn.l2_normalize(feat1, 1)
        feat2 = tf.nn.l2_normalize(feat2, 1)
        return tf.matmul(feat1, tf.transpose(feat2), name='cosine_similarity')

    def compute_dot_product(self, feat1, feat2):
        """Return ``feat1 x feat2^T`` without any normalisation."""
        return tf.matmul(feat1, tf.transpose(feat2), name='dot_product')

if __name__ == '__main__':
    # Smoke test: build the inference graph with training disabled, attach
    # both losses, then list the variables registered for regularization.
    im_h, im_w = constants.image_size

    # Float placeholders for the image regions and the two label matrices.
    plh = placeholder_management.PlaceholderManager()
    plh.add_placeholder(
        name='image_regions', dtype=tf.float32, shape=[None, im_h, im_w, 3])
    plh.add_placeholder(
        name='object_labels', dtype=tf.float32,
        shape=[None, constants.num_object_labels])
    plh.add_placeholder(
        name='attribute_labels', dtype=tf.float32,
        shape=[None, constants.num_attribute_labels])

    # Supplies the label vectors consumed by the label embedding graphs.
    word_vec_mgr = word_vector_manager()

    training = False
    obj_atr_inference = ObjectAttributeInference(
        plh['image_regions'],
        word_vec_mgr.object_label_vectors,
        word_vec_mgr.attribute_label_vectors,
        training)

    # Losses are computed on the scores, not on the probabilities.
    object_loss = losses.object_loss(
        obj_atr_inference.object_scores, plh['object_labels'])
    attribute_loss = losses.attribute_loss(
        obj_atr_inference.attribute_scores, plh['attribute_labels'])

    vars_to_regularize = tf.get_collection('to_regularize')
    var_collect.print_var_list(vars_to_regularize, 'to_regularize')