From d1e487473fd509f28daf28dcda856f3c2f1194ec Mon Sep 17 00:00:00 2001
From: Andrew Tulloch <andrew@tullo.ch>
Date: Tue, 13 May 2014 17:31:27 -0700
Subject: [PATCH] SPARK-1791 - SVM implementation does not use threshold
 parameter

Summary:
https://issues.apache.org/jira/browse/SPARK-1791

Simple and backward-compatible fix, since

- anyone who explicitly set a threshold was getting predictions made against a hard-coded cutoff of 0.0 instead of the value they set, i.e. wrong answers.
- anyone who did not set a threshold was already using the default value of 0.0, so their results are unchanged.
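
For illustration only (not part of the patch), a minimal sketch of the
user-facing effect after the fix; the name `training` (an assumed
RDD[LabeledPoint]) and the value variables below are placeholders for
the sketch, not anything introduced by this change:

    import org.apache.spark.mllib.classification.SVMWithSGD

    // Train an SVM model; `training` is an assumed RDD[LabeledPoint].
    val model = SVMWithSGD.train(training, 100)

    // Default threshold is 0.0: a point is labelled 1.0 when its margin
    // is >= 0.0, and 0.0 otherwise.
    val defaultPredictions = model.predict(training.map(_.features))

    // After this fix the configured threshold is the value actually
    // compared against the margin; previously 0.0 was always used.
    model.setThreshold(0.5)
    val stricterPredictions = model.predict(training.map(_.features))

    // Clearing the threshold makes predict return the raw margin.
    model.clearThreshold()
    val rawMargins = model.predict(training.map(_.features))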

Test Plan:
Added a unit test that fails under the old implementation and passes
under the new one.

Author: Andrew Tulloch <andrew@tullo.ch>

Closes #725 from ajtulloch/SPARK-1791-SVM and squashes the following commits:

770f55d [Andrew Tulloch] SPARK-1791 - SVM implementation does not use threshold parameter
---
 .../spark/mllib/classification/SVM.scala      |  2 +-
 .../spark/mllib/classification/SVMSuite.scala | 37 +++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index e05213536e..316ecd713b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -65,7 +65,7 @@ class SVMModel private[mllib] (
       intercept: Double) = {
     val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
     threshold match {
-      case Some(t) => if (margin < 0) 0.0 else 1.0
+      case Some(t) => if (margin < t) 0.0 else 1.0
       case None => margin
     }
   }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 77d6f04b32..886c71dde3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -69,6 +69,43 @@ class SVMSuite extends FunSuite with LocalSparkContext {
     assert(numOffPredictions < input.length / 5)
   }
 
+  test("SVM with threshold") {
+    val nPoints = 10000
+
+    // NOTE: Intercept should be small so the generated labels are split roughly evenly between 0s and 1s
+    val A = 0.01
+    val B = -1.5
+    val C = 1.0
+
+    val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
+
+    val testRDD = sc.parallelize(testData, 2)
+    testRDD.cache()
+
+    val svm = new SVMWithSGD().setIntercept(true)
+    svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
+
+    val model = svm.run(testRDD)
+
+    val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
+    val validationRDD  = sc.parallelize(validationData, 2)
+
+    // Test prediction on RDD.
+
+    var predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 0.0) != predictions.length)
+
+    // High threshold makes all the predictions 0.0
+    model.setThreshold(10000.0)
+    predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 0.0) == predictions.length)
+
+    // Low threshold makes all the predictions 1.0
+    model.setThreshold(-10000.0)
+    predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 1.0) == predictions.length)
+  }
+
   test("SVM using local random SGD") {
     val nPoints = 10000
 
-- 
GitLab