From e3e9c70384028cc0c322ccea14f19d3b6d6b39eb Mon Sep 17 00:00:00 2001 From: MechCoder <manojkumarsivaraj334@gmail.com> Date: Mon, 8 Jun 2015 15:45:12 +0100 Subject: [PATCH] [SPARK-8140] [MLLIB] Remove empty model check in StreamingLinearAlgorithm 1. Prevent creating a map of data to find numFeatures 2. If model is empty, then initialize with a zero vector of numFeature Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #6684 from MechCoder/spark-8140 and squashes the following commits: 7fbf5f9 [MechCoder] [SPARK-8140] Remove empty model check in StreamingLinearAlgorithm And other minor cosmits --- .../apache/spark/mllib/optimization/GradientDescent.scala | 2 +- .../spark/mllib/regression/GeneralizedLinearAlgorithm.scala | 6 +++--- .../spark/mllib/regression/StreamingLinearAlgorithm.scala | 3 --- .../mllib/regression/StreamingLinearRegressionWithSGD.scala | 2 +- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 4b7d0589c9..06e45e10c5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -179,7 +179,7 @@ object GradientDescent extends Logging { * if it's L2 updater; for L1 updater, the same logic is followed. */ var regVal = updater.compute( - weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 + weights, Vectors.zeros(weights.size), 0, 1, regParam)._2 for (i <- 1 to numIterations) { val bcWeights = data.context.broadcast(weights) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 26be30ff9d..6709bd79bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -195,11 +195,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ val initialWeights = { if (numOfLinearPredictor == 1) { - Vectors.dense(new Array[Double](numFeatures)) + Vectors.zeros(numFeatures) } else if (addIntercept) { - Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor)) + Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) } else { - Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor)) + Vectors.zeros(numFeatures * numOfLinearPredictor) } } run(input, initialWeights) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index cea8f3f473..39308e5ae1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -87,9 +87,6 @@ abstract class StreamingLinearAlgorithm[ model match { case Some(m) => m.weights - case None => - val numFeatures = rdd.first().features.size - Vectors.dense(numFeatures) } model = Some(algorithm.run(rdd, initialWeights)) logInfo("Model updated at time %s".format(time.toString)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index a49153bf73..235e043c77 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -79,7 +79,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( this } - /** Set the initial weights. Default: [0.0, 0.0]. */ + /** Set the initial weights. */ def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this -- GitLab