Skip to content
Snippets Groups Projects
Commit e3e9c703 authored by MechCoder's avatar MechCoder Committed by Sean Owen
Browse files

[SPARK-8140] [MLLIB] Remove empty model check in StreamingLinearAlgorithm

1. Prevent creating a map of data to find numFeatures
2. If model is empty, then initialize with a zero vector of numFeature

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #6684 from MechCoder/spark-8140 and squashes the following commits:

7fbf5f9 [MechCoder] [SPARK-8140] Remove empty model check in StreamingLinearAlgorithm And other minor cosmits
parent a1d9e5cc
No related branches found
No related tags found
No related merge requests found
...@@ -179,7 +179,7 @@ object GradientDescent extends Logging { ...@@ -179,7 +179,7 @@ object GradientDescent extends Logging {
* if it's L2 updater; for L1 updater, the same logic is followed. * if it's L2 updater; for L1 updater, the same logic is followed.
*/ */
var regVal = updater.compute( var regVal = updater.compute(
weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 weights, Vectors.zeros(weights.size), 0, 1, regParam)._2
for (i <- 1 to numIterations) { for (i <- 1 to numIterations) {
val bcWeights = data.context.broadcast(weights) val bcWeights = data.context.broadcast(weights)
......
...@@ -195,11 +195,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] ...@@ -195,11 +195,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/ */
val initialWeights = { val initialWeights = {
if (numOfLinearPredictor == 1) { if (numOfLinearPredictor == 1) {
Vectors.dense(new Array[Double](numFeatures)) Vectors.zeros(numFeatures)
} else if (addIntercept) { } else if (addIntercept) {
Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor)) Vectors.zeros((numFeatures + 1) * numOfLinearPredictor)
} else { } else {
Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor)) Vectors.zeros(numFeatures * numOfLinearPredictor)
} }
} }
run(input, initialWeights) run(input, initialWeights)
......
...@@ -87,9 +87,6 @@ abstract class StreamingLinearAlgorithm[ ...@@ -87,9 +87,6 @@ abstract class StreamingLinearAlgorithm[
model match { model match {
case Some(m) => case Some(m) =>
m.weights m.weights
case None =>
val numFeatures = rdd.first().features.size
Vectors.dense(numFeatures)
} }
model = Some(algorithm.run(rdd, initialWeights)) model = Some(algorithm.run(rdd, initialWeights))
logInfo("Model updated at time %s".format(time.toString)) logInfo("Model updated at time %s".format(time.toString))
......
...@@ -79,7 +79,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( ...@@ -79,7 +79,7 @@ class StreamingLinearRegressionWithSGD private[mllib] (
this this
} }
/** Set the initial weights. Default: [0.0, 0.0]. */ /** Set the initial weights. */
def setInitialWeights(initialWeights: Vector): this.type = { def setInitialWeights(initialWeights: Vector): this.type = {
this.model = Some(algorithm.createModel(initialWeights, 0.0)) this.model = Some(algorithm.createModel(initialWeights, 0.0))
this this
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment