diff --git a/losses.py b/losses.py
index 0972793dd4f961b5229d32b9d9f14544bfc8f62c..f860f220b4c2fd7c705982fe0cce85fc38395d54 100644
--- a/losses.py
+++ b/losses.py
@@ -57,21 +57,33 @@ def margin_loss(y, y_pred, margin):
         keep_dims=True, name='correct_score')
     return tf.reduce_mean(tf.maximum(0.0, y_pred + margin - correct_score))
 
-def multilabel_margin_loss(y, y_pred, margin):
-    y_list = tf.unpack(y)
-    y_pred_list = tf.unpack(y)
+
+def multilabel_margin_loss(y, y_pred, margin, num_samples):
+    y_list = tf.unpack(y, num_samples)
+    y_pred_list = tf.unpack(y_pred, num_samples)
     loss = 0.0
-    for y_, y_pred_ in zip(y_list, y_pred_list):
-        partition = tf.dynamic_partition(y_pred_, y_, 2)
-        pos_labels_scores = tf.expand_dims(tf.transpose(partition[1]),1)
-        neg_labels_scores = partition[0]
-        margin_violation = tf.maximum(
-            0, neg-labels_scores + margin - pos_labels_scores)
-        loss += tf.reduce_mean(margin_violation)
-
-    loss /= len(y_list)
+    for i in xrange(num_samples):
+        y_ = y_list[i]
+        y_pred_ = y_pred_list[i]
+        k = tf.reduce_sum(y_)
+        loss += tf.cond(
+            k > 0.5,
+            lambda: multilabel_margin_loss_inner(y_,y_pred_,margin),
+            lambda: tf.constant(0.0))
+    loss /= float(num_samples)
     return loss
+
+def multilabel_margin_loss_inner(y_,y_pred_,margin):
+    partition_ids = tf.cast(y_>0.5,tf.int32)
+    partition = tf.dynamic_partition(y_pred_, partition_ids, 2)
+    pos_labels_scores = tf.expand_dims(partition[1],1)
+    neg_labels_scores = partition[0]
+    margin_violation = tf.maximum(
+        0.0, neg_labels_scores + margin - pos_labels_scores)
+    return tf.reduce_mean(margin_violation)
+
+
 
 def mil_loss(scores, y, type='obj', epsilon=1e-5):
     if type=='obj':
         log_prob = tf.nn.log_softmax(scores)
@@ -82,6 +94,7 @@ def mil_loss(scores, y, type='obj', epsilon=1e-5):
         loss = -tf.reduce_sum(max_region_scores)/tf.maximum(tf.reduce_sum(y),epsilon)
     return loss
 
+
 if __name__=='__main__':
     scores = tf.constant([[0.2, 0.3, 0.7],[0.8, 0.2, 0.9]])
     labels = tf.constant([[1.0, 0.0, 0.0],[0.0, 1.0, 0.0]])