diff --git a/losses.py b/losses.py
index 0972793dd4f961b5229d32b9d9f14544bfc8f62c..f860f220b4c2fd7c705982fe0cce85fc38395d54 100644
--- a/losses.py
+++ b/losses.py
@@ -57,21 +57,33 @@ def margin_loss(y, y_pred, margin):
                                   keep_dims=True, name='correct_score')
     return tf.reduce_mean(tf.maximum(0.0, y_pred + margin - correct_score))
 
-def multilabel_margin_loss(y, y_pred, margin):
-    y_list = tf.unpack(y)
-    y_pred_list = tf.unpack(y)
+
+def multilabel_margin_loss(y, y_pred, margin, num_samples):
+    y_list = tf.unpack(y, num_samples)
+    y_pred_list = tf.unpack(y_pred, num_samples)
     loss = 0.0
-    for y_, y_pred_ in zip(y_list, y_pred_list):
-        partition = tf.dynamic_partition(y_pred_, y_, 2)
-        pos_labels_scores = tf.expand_dims(tf.transpose(partition[1]),1)
-        neg_labels_scores = partition[0]
-        margin_violation = tf.maximum(
-            0, neg-labels_scores + margin - pos_labels_scores)
-        loss += tf.reduce_mean(margin_violation)
-
-    loss /= len(y_list)
+    for i in xrange(num_samples):
+        y_ = y_list[i]
+        y_pred_ = y_pred_list[i]
+        k = tf.reduce_sum(y_)
+        loss += tf.cond(
+            k > 0.5,
+            lambda: multilabel_margin_loss_inner(y_,y_pred_,margin),
+            lambda: tf.constant(0.0))
+    loss /= float(num_samples)
     return loss
+
     
+def multilabel_margin_loss_inner(y_,y_pred_,margin):
+    partition_ids = tf.cast(y_>0.5,tf.int32)
+    partition = tf.dynamic_partition(y_pred_, partition_ids, 2)
+    pos_labels_scores = tf.expand_dims(partition[1],1)
+    neg_labels_scores = partition[0]
+    margin_violation = tf.maximum(
+        0.0, neg_labels_scores + margin - pos_labels_scores)
+    return tf.reduce_mean(margin_violation)
+
+
 def mil_loss(scores, y, type='obj', epsilon=1e-5):
     if type=='obj':
         log_prob = tf.nn.log_softmax(scores)
@@ -82,6 +94,7 @@ def mil_loss(scores, y, type='obj', epsilon=1e-5):
     loss = -tf.reduce_sum(max_region_scores)/tf.maximum(tf.reduce_sum(y),epsilon)
     return loss
 
+
 if __name__=='__main__':
     scores = tf.constant([[0.2, 0.3, 0.7],[0.8, 0.2, 0.9]])
     labels = tf.constant([[1.0, 0.0, 0.0],[0.0, 1.0, 0.0]])