Skip to content
Snippets Groups Projects
Commit f6fa768f authored by tgupta6's avatar tgupta6
Browse files

multilabel_margin_loss

parent 98f2b8b5
No related branches found
No related tags found
No related merge requests found
......@@ -57,6 +57,21 @@ def margin_loss(y, y_pred, margin):
keep_dims=True, name='correct_score')
return tf.reduce_mean(tf.maximum(0.0, y_pred + margin - correct_score))
def multilabel_margin_loss(y, y_pred, margin):
y_list = tf.unpack(y)
y_pred_list = tf.unpack(y):
loss = 0.0
for y_, y_pred_ in zip(y_list, y_pred_list):
partition = tf.dynamic_partition(y_pred_, y_, 2)
pos_labels_scores = tf.expand_dims(tf.transpose(partition[1]),1)
neg_labels_scores = partition[0]
margin_violation = tf.maximum(
0, neg-labels_scores + margin - pos_labels_scores)
loss += tf.reduce_mean(margin_violation)
loss /= len(y_list)
return loss
def mil_loss(scores, y, type='obj', epsilon=1e-5):
if type=='obj':
log_prob = tf.nn.log_softmax(scores)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment