diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala index eabd2162e287f77d465a7a140743abb0196b27ad..6a3893d0e41d2240861712ae1100f717cd7a6050 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -88,8 +88,7 @@ class StreamingLogisticRegressionWithSGD private[mllib] ( /** Set the initial weights. Default: [0.0, 0.0]. */ def setInitialWeights(initialWeights: Vector): this.type = { - this.model = Option(algorithm.createModel(initialWeights, 0.0)) + this.model = Some(algorithm.createModel(initialWeights, 0.0)) this } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index 39a0dee931d3d69f97561fd11d62402d24d214d5..44a8dbb994cfb35945c2a5dca856ba904f2380d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.streaming.dstream.DStream /** @@ -58,7 +58,7 @@ abstract class StreamingLinearAlgorithm[ A <: GeneralizedLinearAlgorithm[M]] extends Logging { /** The model to be updated and used for prediction. */ - protected var model: Option[M] = null + protected var model: Option[M] = None /** The algorithm to use for updating. */ protected val algorithm: A @@ -77,18 +77,25 @@ abstract class StreamingLinearAlgorithm[ * @param data DStream containing labeled data */ def trainOn(data: DStream[LabeledPoint]) { - if (Option(model) == None) { - logError("Model must be initialized before starting training") - throw new IllegalArgumentException + if (model.isEmpty) { + throw new IllegalArgumentException("Model must be initialized before starting training.") } data.foreachRDD { (rdd, time) => - model = Option(algorithm.run(rdd, model.get.weights)) - logInfo("Model updated at time %s".format(time.toString)) - val display = model.get.weights.size match { - case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") - case _ => model.get.weights.toArray.mkString("[", ",", "]") + val initialWeights = + model match { + case Some(m) => + m.weights + case None => + val numFeatures = rdd.first().features.size + Vectors.dense(numFeatures) } - logInfo("Current model: weights, %s".format (display)) + model = Some(algorithm.run(rdd, initialWeights)) + logInfo("Model updated at time %s".format(time.toString)) + val display = model.get.weights.size match { + case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") + case _ => model.get.weights.toArray.mkString("[", ",", "]") + } + logInfo("Current model: weights, %s".format (display)) } } @@ -99,10 +106,8 @@ abstract class StreamingLinearAlgorithm[ * @return DStream containing predictions */ def predictOn(data: DStream[Vector]): DStream[Double] = { - if (Option(model) == None) { - val msg = "Model must be initialized before starting prediction" - logError(msg) - throw new IllegalArgumentException(msg) + if (model.isEmpty) { + throw new IllegalArgumentException("Model must be initialized before starting prediction.") } data.map(model.get.predict) } @@ -114,10 +119,8 @@ abstract class StreamingLinearAlgorithm[ * @return DStream containing the input keys and the predictions as values */ def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = { - if (Option(model) == None) { - val msg = "Model must be initialized before starting prediction" - logError(msg) - throw new IllegalArgumentException(msg) + if (model.isEmpty) { + throw new IllegalArgumentException("Model must be initialized before starting prediction") } data.mapValues(model.get.predict) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index c0625b4880953f5d2b31085e5eae9d96423bd99f..e5e6301127a286d5ca92ae5c9610b205d3bf099f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -79,7 +79,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( /** Set the initial weights. Default: [0.0, 0.0]. */ def setInitialWeights(initialWeights: Vector): this.type = { - this.model = Option(algorithm.createModel(initialWeights, 0.0)) + this.model = Some(algorithm.createModel(initialWeights, 0.0)) this }