diff --git a/losses.py b/losses.py
index 3ba65b80f5ece733bb3b3b5515179b50cc3eedaa..487181c7cb7f6499019c2717b6b445e45a85e154 100644
--- a/losses.py
+++ b/losses.py
@@ -57,6 +57,21 @@ def margin_loss(y, y_pred, margin):
                                   keep_dims=True, name='correct_score')
     return tf.reduce_mean(tf.maximum(0.0, y_pred + margin - correct_score))
 
+def multilabel_margin_loss(y, y_pred, margin):
+    y_list = tf.unpack(y)
+    y_pred_list = tf.unpack(y):
+    loss = 0.0
+    for y_, y_pred_ in zip(y_list, y_pred_list):
+        partition = tf.dynamic_partition(y_pred_, y_, 2)
+        pos_labels_scores = tf.expand_dims(tf.transpose(partition[1]),1)
+        neg_labels_scores = partition[0]
+        margin_violation = tf.maximum(
+            0, neg-labels_scores + margin - pos_labels_scores)
+        loss += tf.reduce_mean(margin_violation)
+
+    loss /= len(y_list)
+    return loss
+    
 def mil_loss(scores, y, type='obj', epsilon=1e-5):
     if type=='obj':
         log_prob = tf.nn.log_softmax(scores)