diff --git a/answer_classifier_cached_features/inference.py b/answer_classifier_cached_features/inference.py index 11082891125a6565fa6628dfa7a141ecc538a203..34b3e1dc47980822137e15d41e32a0bff717faab 100644 --- a/answer_classifier_cached_features/inference.py +++ b/answer_classifier_cached_features/inference.py @@ -8,6 +8,8 @@ class AnswerInference(): self, object_feat, attribute_feat, + obj_detector_scores, + atr_detector_scores, answer_region_scores, question_vert_concat, answers_vert_concat, @@ -44,11 +46,11 @@ class AnswerInference(): noun_embed.append(self.noun_embed[key1][j]) adjective_embed.append(self.adjective_embed[key2][j]) - self.selected_noun_adjective[j] = self.inner_product_selection( - self.object_feat[j], - self.attribute_feat[j], - noun_embed, - adjective_embed) + # self.selected_noun_adjective[j] = self.inner_product_selection( + # self.object_feat[j], + # self.attribute_feat[j], + # noun_embed, + # adjective_embed) # self.per_region_answer_scores = [None]*self.batch_size obj_atr_qa_feat = [None]*self.batch_size @@ -76,9 +78,31 @@ class AnswerInference(): a_feat, [self.num_regions, 1, 1]) + # 100 x 1000 + obj_det_feat = tf.expand_dims( + self.obj_detector_scores[j], + 1) + + # 100 x 18 x 1000 + obj_det_feat = tf.tile( + self.obj_det_feat, + [1, self.num_answers, 1]) + + # 100 x 1000 + atr_det_feat = tf.expand_dims( + self.atr_detector_scores[j], + 1) + + # 100 x 18 x 1000 + atr_det_feat = tf.tile( + self.atr_det_feat, + [1, self.num_answers, 1]) + obj_atr_qa_feat[j] = tf.concat( 2, - [self.selected_noun_adjective[j], q_feat, a_feat]) + [obj_det_feat, atr_det_feat, q_feat, a_feat]) + + # [self.selected_noun_adjective[j], q_feat, a_feat]) # obj_atr_qa_feat[j] = tf.expand_dims( # obj_atr_qa_feat[j], diff --git a/answer_classifier_cached_features/train.py b/answer_classifier_cached_features/train.py index 7d485fdb351bb0e69c2fc617b5988b17f8f8b425..294080382d739a612d6453b6d2d03591a1fed495 100644 --- a/answer_classifier_cached_features/train.py +++ b/answer_classifier_cached_features/train.py @@ -71,9 +71,19 @@ class graph_creator(): self.split_obj_atr_inference_output() self.object_feat = self.object_embed_with_answers self.attribute_feat = self.attribute_embed_with_answers + self.obj_detector_scores = self.object_scores_with_answers + self.atr_detector_scores = self.attribute_scores_with_answers else: self.object_feat = self.obj_atr_inference.object_embed self.attribute_feat = self.obj_atr_inference.attribute_embed + self.obj_detector_scores = tf.split( + 0, + self.batch_size, + self.obj_atr_inference.object_scores) + self.atr_detector_scores = tf.split( + 0, + self.batch_size, + self.obj_atr_inference.attribute_scores) self.object_feat = tf.split( 0, @@ -105,6 +115,8 @@ class graph_creator(): self.answer_inference = answer_graph.AnswerInference( self.object_feat, self.attribute_feat, + self.obj_detector_scores, + self.atr_detector_scores, self.relevance_inference.answer_region_prob, self.question_embed_concat, self.answers_embed_concat,