diff --git a/losses.py b/losses.py index 3ba65b80f5ece733bb3b3b5515179b50cc3eedaa..487181c7cb7f6499019c2717b6b445e45a85e154 100644 --- a/losses.py +++ b/losses.py @@ -57,6 +57,21 @@ def margin_loss(y, y_pred, margin): keep_dims=True, name='correct_score') return tf.reduce_mean(tf.maximum(0.0, y_pred + margin - correct_score)) +def multilabel_margin_loss(y, y_pred, margin): + y_list = tf.unpack(y) + y_pred_list = tf.unpack(y): + loss = 0.0 + for y_, y_pred_ in zip(y_list, y_pred_list): + partition = tf.dynamic_partition(y_pred_, y_, 2) + pos_labels_scores = tf.expand_dims(tf.transpose(partition[1]),1) + neg_labels_scores = partition[0] + margin_violation = tf.maximum( + 0, neg-labels_scores + margin - pos_labels_scores) + loss += tf.reduce_mean(margin_violation) + + loss /= len(y_list) + return loss + def mil_loss(scores, y, type='obj', epsilon=1e-5): if type=='obj': log_prob = tf.nn.log_softmax(scores)