diff --git a/losses.py b/losses.py index f860f220b4c2fd7c705982fe0cce85fc38395d54..be74126cb9ec5b8231516b2f710c2f112a56f391 100644 --- a/losses.py +++ b/losses.py @@ -84,7 +84,7 @@ def multilabel_margin_loss_inner(y_,y_pred_,margin): return tf.reduce_mean(margin_violation) -def mil_loss(scores, y, type='obj', epsilon=1e-5): +def mil_loss_prob(scores, y, type='obj', epsilon=1e-5): if type=='obj': log_prob = tf.nn.log_softmax(scores) elif type=='atr': @@ -94,6 +94,16 @@ def mil_loss(scores, y, type='obj', epsilon=1e-5): loss = -tf.reduce_sum(max_region_scores)/tf.maximum(tf.reduce_sum(y),epsilon) return loss +def mil_loss(scores, y, type='obj', epsilon=1e-5): + if type=='obj': + log_prob = scores + elif type=='atr': + log_prob = tf.log(tf.maximum(epsilon, tf.nn.sigmoid(scores))) + + max_region_scores = tf.minimum(tf.reduce_max(log_prob*y,0)-1.0,0.0) + loss = -tf.reduce_mean(max_region_scores)#/tf.maximum(tf.reduce_sum(y),epsilon) + return loss + if __name__=='__main__': scores = tf.constant([[0.2, 0.3, 0.7],[0.8, 0.2, 0.9]])