from tftools import layers
import tensorflow as tf
import pdb

class AnswerInference():
    """Score candidate answers for a batch of questions against image regions.

    For each example ``j`` in the batch:

    * the question is flattened to one row, tiled once per candidate answer
      and concatenated column-wise with the answers, then projected by a
      two-layer fully connected net (with dropout) into a ``space_dim``
      space;
    * the object and attribute features are concatenated column-wise and
      projected the same way into the same ``space_dim`` space;
    * the dot products between the two projections give a per-region score
      for every answer. These scores are weighted elementwise by
      ``answer_region_scores`` and summed over regions.

    All graph construction happens in ``__init__`` under the
    ``answer_graph`` variable scope. Projection weights are created for
    example 0 and reused (``reuse=True``) for every later example, so the
    whole batch shares the same parameters.

    Args:
        object_feat: sequence of per-example object-feature tensors. Its
            length is taken as the batch size. Each entry is concatenated
            with ``attribute_feat[j]`` along axis 1, so it is presumably
            ``(num_regions, obj_dim)`` -- TODO confirm.
        attribute_feat: sequence of per-example attribute-feature tensors,
            presumably ``(num_regions, atr_dim)``.
        answer_region_scores: sequence of per-example tensors. Each is
            transposed and multiplied elementwise with the per-region answer
            scores, so it is presumably ``(num_answers, num_regions)`` --
            TODO confirm.
        question_vert_concat: sequence of per-example question tensors. Each
            is reshaped to a single row ``[1, -1]``.
        answers_vert_concat: sequence of per-example answer tensors with one
            row per candidate answer (presumably ``(num_answers, ans_dim)``).
        num_answers: number of candidate answers per question. The question
            row is tiled this many times.
        space_dim: dimensionality of the shared projection space.
        keep_prob: dropout keep probability passed to every
            ``tf.nn.dropout``. Use 1.0 to disable dropout.

    Attributes built (each a list with one entry per example, except
    ``answers``, which aliases ``answers_vert_concat``):
        question, answers, qa_proj, obj_atr_proj,
        per_region_answer_scores, per_region_answer_prob, answer_score.
    """

    def __init__(
            self,
            object_feat,
            attribute_feat,
            answer_region_scores,
            question_vert_concat,
            answers_vert_concat,
            num_answers,
            space_dim,
            keep_prob):

        self.batch_size = len(object_feat)
        self.object_feat = object_feat
        self.attribute_feat = attribute_feat
        self.answer_region_scores = answer_region_scores
        self.question_vert_concat = question_vert_concat
        self.answers_vert_concat = answers_vert_concat
        self.num_answers = num_answers
        self.keep_prob = keep_prob

        # `range` (not `xrange`) so the module also imports under Python 3.
        # Every list below is built in example order, matching the op
        # creation order of the original per-example loops.
        batch = range(self.batch_size)
        with tf.variable_scope('answer_graph'):
            # Flatten each question into a single row so it can be tiled
            # once per candidate answer.
            self.question = [
                tf.reshape(self.question_vert_concat[j], [1, -1])
                for j in batch]

            self.answers = self.answers_vert_concat

            # Example 0 creates the projection variables (reuse=None); later
            # examples reuse them, so all examples share the same weights.
            self.qa_proj = [
                self.project_question_answer(
                    self.question[j],
                    self.answers[j],
                    space_dim,
                    'question_answer_projection',
                    None if j == 0 else True)
                for j in batch]

            self.obj_atr_proj = [
                self.project_object_attribute(
                    self.object_feat[j],
                    self.attribute_feat[j],
                    space_dim,
                    'object_attribute_projection',
                    None if j == 0 else True)
                for j in batch]

            # Region embeddings times the transposed question+answer
            # embeddings: presumably (num_regions, num_answers).
            self.per_region_answer_scores = [
                tf.matmul(
                    self.obj_atr_proj[j],
                    tf.transpose(self.qa_proj[j]),
                    name='per_region_answer_scores')
                for j in batch]

            self.per_region_answer_prob = [None] * self.batch_size
            self.answer_score = [None] * self.batch_size
            for j in batch:
                # `name` must be passed by keyword. In TensorFlow releases
                # whose tf.nn.softmax takes a `dim` argument (0.12 onward),
                # the second positional parameter is `dim`, not `name`.
                self.per_region_answer_prob[j] = tf.nn.softmax(
                    self.per_region_answer_scores[j],
                    name='per_region_answer_prob_softmax')

                # NOTE(review): the answer score is built from the raw
                # per-region scores, not from per_region_answer_prob (which
                # is only exposed as an attribute) -- confirm this is
                # intended.
                answer_score_tmp = tf.mul(
                    self.per_region_answer_scores[j],
                    tf.transpose(self.answer_region_scores[j]))

                # Sum over regions (axis 0), keeping a leading axis of
                # size 1.
                self.answer_score[j] = tf.reduce_sum(
                    answer_score_tmp,
                    0,
                    keep_dims=True)

    def concat_object_attribute(self, object_feat, attribute_feat,
                                name='concat'):
        """Concatenate object and attribute features along axis 1.

        Args:
            object_feat: object-feature tensor (rank >= 2).
            attribute_feat: attribute-feature tensor whose non-concatenated
                dimensions match ``object_feat``.
            name: name of the concat op. Defaults to ``'concat'``, which is
                TensorFlow's own default, so existing calls behave as before.

        Returns:
            The column-wise concatenation ``[object_feat, attribute_feat]``.
        """
        # Pre-1.0 TensorFlow argument order: tf.concat(axis, values).
        return tf.concat(1, [object_feat, attribute_feat], name=name)

    def project_question_answer(self, question, answer, space_dim, scope_name, reuse):
        """Project every (question, candidate answer) pair into the shared space.

        The question row is tiled ``self.num_answers`` times and concatenated
        column-wise with ``answer``. The result goes through two fully
        connected layers (1000 units, then ``space_dim``), each followed by
        dropout with ``self.keep_prob``.

        Args:
            question: question tensor, presumably a single row
                ``[1, q_dim]`` as produced in ``__init__``.
            answer: answer tensor, presumably with ``self.num_answers`` rows.
            space_dim: output dimensionality.
            scope_name: variable scope that holds the layer weights.
            reuse: ``None`` to create the weights, ``True`` to reuse them.

        Returns:
            The projected tensor, presumably ``(num_answers, space_dim)``.
        """
        with tf.variable_scope(scope_name, reuse=reuse):
            question_replicated = tf.tile(
                question,
                [self.num_answers, 1],
                name='replicate_questions')
            qa = tf.concat(1, [question_replicated, answer])

            # `layers.full` comes from the project's tftools package. It is
            # assumed to be a fully connected layer whose variables are
            # scoped under `scope_name`. To disable dropout, feed
            # keep_prob=1.0 rather than editing this code.
            qa_proj1 = tf.nn.dropout(
                layers.full(qa, 1000, 'fc1'), self.keep_prob)
            qa_proj2 = tf.nn.dropout(
                layers.full(qa_proj1, space_dim, 'fc2'), self.keep_prob)

        return qa_proj2

    def project_object_attribute(self, object_feat, attribute_feat, space_dim, scope_name, reuse):
        """Project concatenated object+attribute region features into the shared space.

        The features are concatenated column-wise and passed through two
        fully connected layers (600 units, then ``space_dim``), each followed
        by dropout with ``self.keep_prob``.

        Args:
            object_feat: object-feature tensor, presumably
                ``(num_regions, obj_dim)``.
            attribute_feat: attribute-feature tensor, presumably
                ``(num_regions, atr_dim)``.
            space_dim: output dimensionality.
            scope_name: variable scope that holds the layer weights.
            reuse: ``None`` to create the weights, ``True`` to reuse them.

        Returns:
            The projected tensor, presumably ``(num_regions, space_dim)``.
        """
        with tf.variable_scope(scope_name, reuse=reuse):
            obj_atr_feat = self.concat_object_attribute(
                object_feat,
                attribute_feat,
                name='object_attribute_concat')

            # Same layer helper as in project_question_answer. Feed
            # keep_prob=1.0 to disable dropout.
            obj_atr_proj1 = tf.nn.dropout(
                layers.full(obj_atr_feat, 600, 'fc1'), self.keep_prob)
            obj_atr_proj2 = tf.nn.dropout(
                layers.full(obj_atr_proj1, space_dim, 'fc2'), self.keep_prob)

        return obj_atr_proj2