from tftools import layers
import tensorflow as tf
import constants
import pdb

class AnswerInference():
    """Build the TensorFlow sub-graph that scores candidate answers.

    For each sample in the batch the constructor
      1. computes, per region and per candidate answer, the best inner-product
         match between region object/attribute features and the candidate's
         noun/adjective embeddings (``inner_product_selection``),
      2. tiles question, answer, yes/no and detector-score features so they
         share a ``[num_regions, num_answers, dim]`` layout,
      3. projects the visual and the textual feature groups with separate
         1x1 conv + batch-norm layers, sums them and applies ReLU,
      4. reduces the result to one score per (region, answer) pair, and
      5. weights those scores by ``answer_region_scores`` and sums over
         regions to obtain one score per candidate answer.

    Results are exposed as attributes, most notably
    ``per_region_answer_scores``, ``per_region_answer_prob`` and
    ``answer_score`` (each a list with one tensor per sample).

    The code targets the pre-1.0 TensorFlow API (``tf.pack``, ``tf.mul``,
    ``tf.concat(dim, values)``).
    """

    def __init__(
            self,
            object_feat,
            attribute_feat,
            obj_detector_scores,
            atr_detector_scores,
            answer_region_scores,
            question_vert_concat,
            answers_vert_concat,
            noun_embed,
            adjective_embed,
            num_answers,
            yes_no_feat,
            keep_prob,
            is_training=True):
        """Construct the answer-inference graph.

        Args:
            object_feat: Sequence (one entry per sample) of 2-D region object
                features; its length defines the batch size.
            attribute_feat: Sequence of 2-D region attribute features.
            obj_detector_scores: Sequence of per-region object detector
                scores (2-D per sample).
            atr_detector_scores: Sequence of per-region attribute detector
                scores (2-D per sample).
            answer_region_scores: Sequence of per-sample region relevance
                scores; each entry is transposed before weighting the
                per-region answer scores.
            question_vert_concat: Sequence of per-sample question features;
                each entry is flattened to a single row.
            answers_vert_concat: Sequence of per-sample answer features
                (presumably ``[num_answers, dim]`` -- confirm with caller).
            noun_embed: Dict keyed by ``'positive_nouns'`` and
                ``'negative_nouns_<i>'``; each value is indexable by sample.
            adjective_embed: Dict keyed by ``'positive_adjectives'`` and
                ``'negative_adjectives_<i>'``; each value is indexable by
                sample.
            num_answers: Number of candidate answers (one positive plus
                ``num_answers - 1`` negatives).
            yes_no_feat: Sequence of per-sample yes/no features, tiled over
                regions.
            keep_prob: Dropout keep probability; stored but not used by the
                graph built here.
            is_training: Python bool forwarded (as a constant tensor) to the
                batch-norm layers.
        """
        self.batch_size = len(object_feat)
        self.object_feat = object_feat
        self.attribute_feat = attribute_feat
        self.obj_detector_scores = obj_detector_scores
        self.atr_detector_scores = atr_detector_scores
        self.num_regions = constants.num_region_proposals
        self.answer_region_scores = answer_region_scores
        self.question_vert_concat = question_vert_concat
        self.answers_vert_concat = answers_vert_concat
        self.noun_embed = noun_embed
        self.adjective_embed = adjective_embed
        self.num_answers = num_answers
        self.yes_no_feat = yes_no_feat
        self.keep_prob = keep_prob
        self.is_training = is_training

        # Candidate order is fixed: the positive answer first, then negatives.
        self.ordered_noun_keys = ['positive_nouns']
        self.ordered_adjective_keys = ['positive_adjectives']
        for i in range(self.num_answers - 1):
            self.ordered_noun_keys.append('negative_nouns_' + str(i))
            self.ordered_adjective_keys.append('negative_adjectives_' + str(i))

        with tf.variable_scope('answer_graph'):
            # Per sample: [num_regions, num_answers, 2] best-match scores.
            self.selected_noun_adjective = [None]*self.batch_size
            for j in range(self.batch_size):
                sample_nouns = [
                    self.noun_embed[key][j]
                    for key in self.ordered_noun_keys]
                sample_adjectives = [
                    self.adjective_embed[key][j]
                    for key in self.ordered_adjective_keys]

                self.selected_noun_adjective[j] = self.inner_product_selection(
                    self.object_feat[j],
                    self.attribute_feat[j],
                    sample_nouns,
                    sample_adjectives)

            q_a_feat_list = [None]*self.batch_size
            obj_atr_det_list = [None]*self.batch_size
            for j in range(self.batch_size):
                # Question: flatten to one row, then replicate for every
                # (region, answer) pair -> [num_regions, num_answers, q_dim].
                q_feat = tf.reshape(
                    self.question_vert_concat[j],
                    [1, -1])
                q_feat = tf.expand_dims(q_feat, 0)
                q_feat = tf.tile(
                    q_feat,
                    [self.num_regions, self.num_answers, 1])

                # Answers: replicate across regions.
                a_feat = tf.expand_dims(
                    self.answers_vert_concat[j],
                    0)
                a_feat = tf.tile(
                    a_feat,
                    [self.num_regions, 1, 1])

                # Object detector scores: replicate across answers.
                obj_det_feat = tf.expand_dims(
                    self.obj_detector_scores[j],
                    1)
                obj_det_feat = tf.tile(
                    obj_det_feat,
                    [1, self.num_answers, 1])

                # Attribute detector scores: replicate across answers.
                atr_det_feat = tf.expand_dims(
                    self.atr_detector_scores[j],
                    1)
                atr_det_feat = tf.tile(
                    atr_det_feat,
                    [1, self.num_answers, 1])

                # Yes/no features: replicate across regions.
                yes_no_feat_ = tf.expand_dims(
                    self.yes_no_feat[j],
                    0)
                yes_no_feat_ = tf.tile(
                    yes_no_feat_,
                    [self.num_regions, 1, 1])

                # Visual evidence and textual evidence are kept in separate
                # groups so each gets its own projection below.
                obj_atr_det_list[j] = tf.concat(
                    2,
                    [self.selected_noun_adjective[j], obj_det_feat, atr_det_feat])

                q_a_feat_list[j] = tf.concat(
                    2,
                    [yes_no_feat_, q_feat, a_feat])

            # Stack samples along a new leading batch axis.
            self.obj_atr_det_packed = tf.pack(obj_atr_det_list)
            self.q_a_feat_packed = tf.pack(q_a_feat_list)

            self.obj_atr_det_conv_bn = self.conv_bn(
                self.obj_atr_det_packed,
                2500,
                'obj_atr_det_conv_bn')

            self.q_a_feat_conv_bn = self.conv_bn(
                self.q_a_feat_packed,
                2500,
                'q_a_feat_conv_bn')

            # Fuse the two projected groups.
            self.obj_atr_qa_feat = tf.nn.relu(
                self.obj_atr_det_conv_bn + self.q_a_feat_conv_bn)

            # The fused feature is the input of the scoring head. (The
            # original read self.per_region_answer_scores here before it was
            # ever assigned, which raised AttributeError.)
            # NOTE(review): this batch-norm + ReLU stage is kept from the
            # original code; it may be redundant after the fused ReLU above
            # -- confirm whether it should stay.
            self.per_region_answer_scores = tf.nn.relu(
                layers.batch_norm(
                    self.obj_atr_qa_feat,
                    tf.constant(self.is_training)))

            # 1x1 conv down to a single score per (region, answer) pair.
            self.per_region_answer_scores = layers.conv2d(
                self.per_region_answer_scores,
                1,
                1,
                'per_region_ans_score_conv_2',
                func = None)

            # Debug output of the score tensor's static shape.
            print(self.per_region_answer_scores.get_shape())

            # Drop the trailing singleton channel axis, then split the batch
            # back into a per-sample list.
            self.per_region_answer_scores = tf.squeeze(
                self.per_region_answer_scores,
                [3])
            self.per_region_answer_scores = tf.unpack(
                self.per_region_answer_scores)

            self.per_region_answer_prob = [None]*self.batch_size
            self.answer_score = [None]*self.batch_size
            for j in range(self.batch_size):
                # `name` is passed by keyword: newer TensorFlow signatures
                # take `dim` as the second positional argument.
                self.per_region_answer_prob[j] = tf.nn.softmax(
                    self.per_region_answer_scores[j],
                    name='per_region_answer_prob_softmax')

                # The final answer score weights the raw (pre-softmax)
                # per-region scores by the region relevance scores and sums
                # over regions.
                weighted_scores = tf.mul(
                    self.per_region_answer_scores[j],
                    tf.transpose(self.answer_region_scores[j]))

                self.answer_score[j] = tf.reduce_sum(
                    weighted_scores,
                    0,
                    keep_dims=True)

    def inner_product_selection(self, obj_feat, atr_feat, noun_embed, adjective_embed):
        """Score how well each region matches each candidate's words.

        For every candidate answer the region features are multiplied with
        the candidate's word embeddings, and the maximum inner product over
        the words is kept, once for nouns (against ``obj_feat``) and once for
        adjectives (against ``atr_feat``).

        Args:
            obj_feat: 2-D tensor of region object features, one row per
                region.
            atr_feat: 2-D tensor of region attribute features, one row per
                region.
            noun_embed: List with one 2-D noun-embedding tensor per
                candidate answer (one row per word).
            adjective_embed: List with one 2-D adjective-embedding tensor
                per candidate answer (one row per word).

        Returns:
            Tensor of shape ``[num_regions, num_candidates, 2]`` whose last
            axis holds (best noun score, best adjective score).

        Raises:
            ValueError: If the two embedding lists are empty or differ in
                length.
        """
        if len(noun_embed) != len(adjective_embed):
            raise ValueError(
                'noun_embed and adjective_embed must have the same length')
        if len(noun_embed) == 0:
            raise ValueError('at least one candidate answer is required')

        feats = []
        # Iterate over the candidates actually supplied instead of assuming
        # a fixed count of 18.
        for nouns, adjectives in zip(noun_embed, adjective_embed):
            noun_scores = tf.matmul(obj_feat, tf.transpose(nouns))
            best_noun = tf.reduce_max(noun_scores, 1, keep_dims=True)

            adjective_scores = tf.matmul(atr_feat, tf.transpose(adjectives))
            best_adjective = tf.reduce_max(adjective_scores, 1, keep_dims=True)

            # num_regions x 2
            feats.append(tf.concat(1, [best_noun, best_adjective]))

        # Debug output of the last candidate's feature shape.
        print('feat {}'.format(feats[-1].get_shape()))

        # Packed as [num_candidates, num_regions, 2]; swap the first two axes
        # so regions come first.
        feats = tf.transpose(tf.pack(feats), [1,0,2])

        return feats

    def conv_bn(self, feat, out_dim, name):
        """Apply a linear 1x1 convolution followed by batch normalization.

        Args:
            feat: Input tensor passed to ``layers.conv2d``.
            out_dim: Number of output channels of the convolution.
            name: Name given to the convolution layer.

        Returns:
            The batch-normalized convolution output (no non-linearity).
        """
        conv_feat = layers.conv2d(
            feat,
            1,
            out_dim,
            name,
            func = None)

        bn_conv_feat = layers.batch_norm(
            conv_feat,
            tf.constant(self.is_training))

        return bn_conv_feat
    
    # def elementwise_product(self, obj_feat, atr_feat, ques_feat, ans_feat):
    #     tiled_ques = tf.tile(tf.reshape(ques_feat,[1, -1]),[self.num_answers,1])
    #     qa_feat = tf.concat(
    #         1,
    #         [tiled_ques, ans_feat])
    #     qa_feat = tf.tile(qa_feat, [1,2]) 
        
    #     obj_atr_feat = tf.concat(
    #         1,
    #         [obj_feat, atr_feat])
    #     obj_atr_feat = tf.tile(obj_atr_feat, [1,5])
    #     obj_atr_feat = tf.expand_dims(obj_atr_feat,1)
        
    #     feat = obj_atr_feat*qa_feat
        
    #     return feat
                
            

    # def concat_object_attribute(self, object_feat, attribute_feat):
    #     object_attribute = tf.concat(
    #         1,
    #         [object_feat, attribute_feat])

    #     return object_attribute


    # def project_question_answer(self, question, answer, space_dim, scope_name, reuse):
    #     with tf.variable_scope(scope_name, reuse=reuse):
    #         question_replicated = tf.tile(
    #             question,
    #             [self.num_answers, 1],
    #             name='replicate_questions')
    #         qa = tf.concat(
    #             1,
    #             [question_replicated, answer])
                
    #         qa_proj1 = tf.nn.dropout(layers.full(qa, 1000, 'fc1'), self.keep_prob)
    #         qa_proj2 = tf.nn.dropout(layers.full(qa_proj1, space_dim, 'fc2'), self.keep_prob)

    #         # qa_proj1 = layers.full(qa, 1000, 'fc1')
    #         # qa_proj2 = layers.full(qa_proj1, space_dim, 'fc2')

    #     return qa_proj2
                
    # def project_object_attribute(self, object_feat, attribute_feat, space_dim, scope_name, reuse):
    #     with tf.variable_scope(scope_name, reuse=reuse):
    #         obj_atr_feat = tf.concat(
    #             1,
    #             [object_feat, attribute_feat],
    #             name='object_attribute_concat')

    #         obj_atr_proj1 = tf.nn.dropout(layers.full(obj_atr_feat, 600, 'fc1'), self.keep_prob)
    #         obj_atr_proj2 = tf.nn.dropout(layers.full(obj_atr_proj1, space_dim, 'fc2'), self.keep_prob)

    #         # obj_atr_proj1 = layers.full(obj_atr_feat, 600, 'fc1')
    #         # obj_atr_proj2 = layers.full(obj_atr_proj1, space_dim, 'fc2')
    #     return obj_atr_proj2