From 87a9dd898ff51fd110799edae087d59f6b714211 Mon Sep 17 00:00:00 2001
From: Reynold Xin <reynoldx@gmail.com>
Date: Tue, 23 Jul 2013 12:13:27 -0700
Subject: [PATCH] Made RegressionModel serializable and added unit tests to
 make sure the predict methods work.

---
 .../mllib/optimization/GradientDescent.scala  |  6 ++--
 .../spark/mllib/optimization/Updater.scala    |  6 ++--
 .../mllib/regression/LogisticRegression.scala |  6 +++-
 .../spark/mllib/regression/Regression.scala   |  2 +-
 .../mllib/regression/RidgeRegression.scala    |  5 ++-
 .../regression/LogisticRegressionSuite.scala  | 33 +++++++++++++++----
 6 files changed, 42 insertions(+), 16 deletions(-)

diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
index 4c996c0903..185a2a24f6 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
@@ -39,9 +39,9 @@ object GradientDescent {
    * @param miniBatchFraction - fraction of the input data set that should be used for
    *                            one iteration of SGD. Default value 1.0.
    *
-   * @return weights - Column matrix containing weights for every feature.
-   * @return stochasticLossHistory - Array containing the stochastic loss computed for 
-   *                                 every iteration.
+   * @return A tuple of two elements. The first element is a column matrix containing the
+   *         weights for every feature, and the second element is an array containing the
+   *         stochastic loss computed for every iteration.
    */
   def runMiniBatchSGD(
     data: RDD[(Double, Array[Double])],
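
Note: callers consume the return value documented above by destructuring the tuple. A minimal,
self-contained sketch of the pattern (sgdStub is an illustrative stand-in, not the real
runMiniBatchSGD, whose full parameter list is elided in this hunk):

    import org.jblas.DoubleMatrix

    // Stand-in that mirrors runMiniBatchSGD's (weights, lossHistory) return shape.
    def sgdStub(nFeatures: Int, nIters: Int): (DoubleMatrix, Array[Double]) =
      (DoubleMatrix.zeros(nFeatures, 1), Array.fill(nIters)(0.0))

    // Destructure the pair into named values.
    val (weights, stochasticLossHistory) = sgdStub(3, 20)
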
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
index b864fd4634..18cb5f3a95 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
@@ -23,13 +23,13 @@ abstract class Updater extends Serializable {
   /**
    * Compute an updated value for weights given the gradient, stepSize and iteration number.
    *
-   * @param weightsOld - Column matrix of size nx1 where n is the number of features.
+   * @param weightsOlds - Column matrix of size nx1 where n is the number of features.
    * @param gradient - Column matrix of size nx1 where n is the number of features.
    * @param stepSize - step size across iterations
    * @param iter - Iteration number
    *
-   * @return weightsNew - Column matrix containing updated weights
-   * @return reg_val - regularization value
+   * @return A tuple of two elements. The first element is a column matrix containing updated
+   *         weights, and the second element is the regularization value.
    */
   def compute(weightsOlds: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int):
       (DoubleMatrix, Double)
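
Note: as a reference point for the contract documented above, a minimal Updater sketch that
performs a plain gradient step with no regularization (the 1/sqrt(iter) step-size decay is an
illustrative choice, not something the trait mandates):

    import org.jblas.DoubleMatrix

    class NoRegUpdater extends Updater {
      override def compute(weightsOlds: DoubleMatrix, gradient: DoubleMatrix,
          stepSize: Double, iter: Int): (DoubleMatrix, Double) = {
        // Shrink the step size as iterations progress.
        val thisIterStepSize = stepSize / math.sqrt(iter)
        // weightsNew = weightsOld - thisIterStepSize * gradient
        val weightsNew = weightsOlds.sub(gradient.mul(thisIterStepSize))
        // No regularization, so the regularization value is 0.
        (weightsNew, 0.0)
      }
    }
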
diff --git a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
index 711e205c39..4b22546017 100644
--- a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
@@ -36,8 +36,12 @@ class LogisticRegressionModel(
   private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*)
 
   override def predict(testData: spark.RDD[Array[Double]]) = {
+    // A small optimization to avoid serializing the entire model. Only the weightsMatrix
+    // and intercept are needed.
+    val localWeights = weightsMatrix
+    val localIntercept = intercept
     testData.map { x =>
-      val margin = new DoubleMatrix(1, x.length, x:_*).mmul(weightsMatrix).get(0) + this.intercept
+      val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept
       1.0/ (1.0 + math.exp(margin * -1))
     }
   }
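
Note: the local-val copies above matter because referencing a field such as this.intercept
inside testData.map captures `this`, so Spark would have to serialize the whole enclosing
model and ship it with every task. A minimal sketch of the pitfall (ToyModel is hypothetical):

    class ToyModel(val intercept: Double) {   // suppose this class is not Serializable
      def predictBad(data: spark.RDD[Double]): spark.RDD[Double] =
        data.map(x => x + intercept)          // `intercept` is `this.intercept`: captures `this`

      def predictGood(data: spark.RDD[Double]): spark.RDD[Double] = {
        val localIntercept = intercept        // copy just the Double into a local val
        data.map(x => x + localIntercept)     // closure now captures only that Double
      }
    }
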
diff --git a/mllib/src/main/scala/spark/mllib/regression/Regression.scala b/mllib/src/main/scala/spark/mllib/regression/Regression.scala
index 645204ddf3..b845ba1a89 100644
--- a/mllib/src/main/scala/spark/mllib/regression/Regression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/Regression.scala
@@ -19,7 +19,7 @@ package spark.mllib.regression
 
 import spark.RDD
 
-trait RegressionModel {
+trait RegressionModel extends Serializable {
   /**
    * Predict values for the given data set using the model trained.
    *
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
index f724edd732..6ba141e8fb 100644
--- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
@@ -37,8 +37,11 @@ class RidgeRegressionModel(
   extends RegressionModel {
 
   override def predict(testData: RDD[Array[Double]]): RDD[Double] = {
+    // A small optimization to avoid serializing the entire model.
+    val localIntercept = this.intercept
+    val localWeights = this.weights
     testData.map { x =>
-      (new DoubleMatrix(1, x.length, x:_*).mmul(this.weights)).get(0) + this.intercept
+      (new DoubleMatrix(1, x.length, x:_*).mmul(localWeights)).get(0) + localIntercept
     }
   }
 
diff --git a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
index 47191d9a5a..6a8098b59d 100644
--- a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
@@ -23,7 +23,6 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
 import spark.SparkContext
-import spark.SparkContext._
 
 
 class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
@@ -51,15 +50,24 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
 
     // y <- A + B*x + rLogis()
     // y <- as.numeric(y > 0)
-    val y = (0 until nPoints).map { i =>
+    val y: Seq[Double] = (0 until nPoints).map { i =>
       val yVal = offset + scale * x1(i) + rLogis(i)
       if (yVal > 0) 1.0 else 0.0
     }
 
-    val testData = (0 until nPoints).map(i => (y(i).toDouble, Array(x1(i))))
+    val testData = (0 until nPoints).map(i => (y(i), Array(x1(i))))
     testData
   }
 
+  def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) {
+    val offPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
+      // A prediction is off if it is more than 0.5 away from the expected value.
+      math.abs(prediction - expected) > 0.5
+    }.size
+    // At least 80% of the predictions should be on target.
+    assert(offPredictions < input.length / 5)
+  }
+
   // Test if we can correctly learn A, B where Y = logistic(A + B*X)
   test("logistic regression") {
     val nPoints = 10000
@@ -70,14 +78,20 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
 
     val testRDD = sc.parallelize(testData, 2)
     testRDD.cache()
-    val lr = new LogisticRegression().setStepSize(10.0)
-                                     .setNumIterations(20)
+    val lr = new LogisticRegression().setStepSize(10.0).setNumIterations(20)
 
     val model = lr.train(testRDD)
 
+    // Test the weights
     val weight0 = model.weights(0)
     assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
     assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
+
+    // Test prediction on RDD.
+    validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData)
+
+    // Test prediction on Array.
+    validatePrediction(testData.map(row => model.predict(row._2)), testData)
   }
 
   test("logistic regression with initial weights") {
@@ -94,13 +108,18 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
     testRDD.cache()
 
     // Use half as many iterations as the previous test.
-    val lr = new LogisticRegression().setStepSize(10.0)
-                                     .setNumIterations(10)
+    val lr = new LogisticRegression().setStepSize(10.0).setNumIterations(10)
 
     val model = lr.train(testRDD, initialWeights)
 
     val weight0 = model.weights(0)
     assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
     assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
+
+    // Test prediction on RDD.
+    validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData)
+
+    // Test prediction on Array.
+    validatePrediction(testData.map(row => model.predict(row._2)), testData)
   }
 }
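
Note: beyond the prediction checks above, the headline change (RegressionModel now extends
Serializable) could also be exercised directly by round-tripping a trained model through Java
serialization; a minimal sketch, not part of this patch:

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    // Writing the model to an ObjectOutputStream should now succeed instead of
    // throwing java.io.NotSerializableException.
    def assertSerializable(model: spark.mllib.regression.RegressionModel) {
      val oos = new ObjectOutputStream(new ByteArrayOutputStream())
      oos.writeObject(model)
      oos.close()
    }
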
-- 
GitLab