diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index e05213536e64ae19bfb285f92aa9f215c359c827..316ecd713b71587c428fec0909716793090d2024 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -65,7 +65,8 @@ class SVMModel private[mllib] (
       intercept: Double) = {
     val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
     threshold match {
-      case Some(t) => if (margin < 0) 0.0 else 1.0
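+      // With a threshold set, return 1.0 when the margin reaches it and 0.0 otherwise.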
+      case Some(t) => if (margin < t) 0.0 else 1.0
       case None => margin
     }
   }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 77d6f04b32320e66860fab9a5f4913672282a65e..886c71dde3af75cd1c17b7381f7dcadb364d83c6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -69,6 +69,45 @@ class SVMSuite extends FunSuite with LocalSparkContext {
     assert(numOffPredictions < input.length / 5)
   }
 
+  test("SVM with threshold") {
+    val nPoints = 10000
+
+    // NOTE: The intercept should be small so the generated data has roughly equal numbers of 0s and 1s
+    val A = 0.01
+    val B = -1.5
+    val C = 1.0
+
+    val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
+
+    val testRDD = sc.parallelize(testData, 2)
+    testRDD.cache()
+
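+    // Fit an intercept because the generated data includes the bias term A.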
+    val svm = new SVMWithSGD().setIntercept(true)
+    svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
+
+    val model = svm.run(testRDD)
+
+    val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
+    val validationRDD = sc.parallelize(validationData, 2)
+
+    // Test prediction on RDD.
+
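+    // With the default threshold, not every prediction should be 0.0.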
+    var predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 0.0) != predictions.length)
+
+    // High threshold makes all the predictions 0.0
+    model.setThreshold(10000.0)
+    predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 0.0) == predictions.length)
+
+    // Low threshold makes all the predictions 1.0
+    model.setThreshold(-10000.0)
+    predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 1.0) == predictions.length)
+  }
+
   test("SVM using local random SGD") {
     val nPoints = 10000