From 659329f9ee51ca8ae6232e07c45b5d9144d49667 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Tue, 3 Feb 2015 00:14:43 -0800
Subject: [PATCH] [minor] update streaming linear algorithms

Author: Xiangrui Meng <meng@databricks.com>

Closes #4329 from mengxr/streaming-lr and squashes the following commits:

78731e1 [Xiangrui Meng] update streaming linear algorithms
---
 .../StreamingLogisticRegressionWithSGD.scala  |  3 +-
 .../regression/StreamingLinearAlgorithm.scala | 41 ++++++++++---------
 .../StreamingLinearRegressionWithSGD.scala    |  2 +-
 3 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
index eabd2162e2..6a3893d0e4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
@@ -88,8 +88,7 @@ class StreamingLogisticRegressionWithSGD private[mllib] (
 
   /** Set the initial weights. Default: [0.0, 0.0]. */
   def setInitialWeights(initialWeights: Vector): this.type = {
-    this.model = Option(algorithm.createModel(initialWeights, 0.0))
+    this.model = Some(algorithm.createModel(initialWeights, 0.0))
     this
   }
-
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index 39a0dee931..44a8dbb994 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -21,7 +21,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.streaming.dstream.DStream
 
 /**
@@ -58,7 +58,7 @@ abstract class StreamingLinearAlgorithm[
     A <: GeneralizedLinearAlgorithm[M]] extends Logging {
 
   /** The model to be updated and used for prediction. */
-  protected var model: Option[M] = null
+  protected var model: Option[M] = None
 
   /** The algorithm to use for updating. */
   protected val algorithm: A
@@ -77,18 +77,25 @@ abstract class StreamingLinearAlgorithm[
    * @param data DStream containing labeled data
    */
   def trainOn(data: DStream[LabeledPoint]) {
-    if (Option(model) == None) {
-      logError("Model must be initialized before starting training")
-      throw new IllegalArgumentException
+    if (model.isEmpty) {
+      throw new IllegalArgumentException("Model must be initialized before starting training.")
     }
     data.foreachRDD { (rdd, time) =>
-        model = Option(algorithm.run(rdd, model.get.weights))
-        logInfo("Model updated at time %s".format(time.toString))
-        val display = model.get.weights.size match {
-          case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
-          case _ => model.get.weights.toArray.mkString("[", ",", "]")
+      val initialWeights =
+        model match {
+          case Some(m) =>
+            m.weights
+          case None =>
+            val numFeatures = rdd.first().features.size
+            Vectors.dense(numFeatures)
         }
-        logInfo("Current model: weights, %s".format (display))
+      model = Some(algorithm.run(rdd, initialWeights))
+      logInfo("Model updated at time %s".format(time.toString))
+      val display = model.get.weights.size match {
+        case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
+        case _ => model.get.weights.toArray.mkString("[", ",", "]")
+      }
+      logInfo("Current model: weights, %s".format (display))
     }
   }
 
@@ -99,10 +106,8 @@ abstract class StreamingLinearAlgorithm[
    * @return DStream containing predictions
    */
   def predictOn(data: DStream[Vector]): DStream[Double] = {
-    if (Option(model) == None) {
-      val msg = "Model must be initialized before starting prediction"
-      logError(msg)
-      throw new IllegalArgumentException(msg)
+    if (model.isEmpty) {
+      throw new IllegalArgumentException("Model must be initialized before starting prediction.")
     }
     data.map(model.get.predict)
   }
@@ -114,10 +119,8 @@ abstract class StreamingLinearAlgorithm[
    * @return DStream containing the input keys and the predictions as values
    */
   def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = {
-    if (Option(model) == None) {
-      val msg = "Model must be initialized before starting prediction"
-      logError(msg)
-      throw new IllegalArgumentException(msg)
+    if (model.isEmpty) {
+      throw new IllegalArgumentException("Model must be initialized before starting prediction")
     }
     data.mapValues(model.get.predict)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
index c0625b4880..e5e6301127 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
@@ -79,7 +79,7 @@ class StreamingLinearRegressionWithSGD private[mllib] (
 
   /** Set the initial weights. Default: [0.0, 0.0]. */
   def setInitialWeights(initialWeights: Vector): this.type = {
-    this.model = Option(algorithm.createModel(initialWeights, 0.0))
+    this.model = Some(algorithm.createModel(initialWeights, 0.0))
     this
   }
 
-- 
GitLab