Commit 10c78607 authored by Yanbo Liang, committed by Sean Owen

[SPARK-6496] [MLLIB] GeneralizedLinearAlgorithm.run(input, initialWeights) should initialize numFeatures

In GeneralizedLinearAlgorithm, ```numFeatures``` defaults to -1, so it must be set to the correct value when run() is called to train a model.
```LogisticRegressionWithLBFGS.run(input)``` works, but calling ```LogisticRegressionWithLBFGS.run(input, initialWeights)``` to train a multiclass classification model throws an exception because numFeatures is never updated.
This PR sets numFeatures at the beginning of GeneralizedLinearAlgorithm.run(input, initialWeights) and adds a test case.
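
The pattern described above can be illustrated with a minimal, Spark-free sketch. It assumes a plain `Seq` in place of an RDD, and the names `Point`, `SketchAlgorithm`, `featureCount` and `SketchDemo` are hypothetical, introduced only for illustration. The scratch-array allocation stands in for whatever downstream code depends on numFeatures; it is not what MLlib actually does.

```scala
// Hypothetical, Spark-free stand-ins; none of these names come from MLlib.
final case class Point(label: Double, features: Array[Double])

class SketchAlgorithm {
  // Mirrors the -1 default from the commit message: "not yet known".
  protected var numFeatures: Int = -1

  def featureCount: Int = numFeatures

  // Entry point without initial weights: infers numFeatures, then delegates.
  def run(input: Seq[Point]): Array[Double] = {
    if (numFeatures < 0) {
      numFeatures = input.head.features.length
    }
    run(input, new Array[Double](numFeatures))
  }

  // Entry point with caller-supplied weights.
  def run(input: Seq[Point], initialWeights: Array[Double]): Array[Double] = {
    // Same kind of guard as the one this commit adds: without it, a direct
    // call to this overload reaches the allocation below with numFeatures == -1.
    if (numFeatures < 0) {
      numFeatures = input.head.features.length
    }
    // Illustrative use of numFeatures; a negative size would throw here.
    val scratch = new Array[Double](numFeatures)
    require(initialWeights.length >= scratch.length,
      "initialWeights must hold at least numFeatures entries")
    initialWeights
  }
}

object SketchDemo {
  def main(args: Array[String]): Unit = {
    val data = Seq(
      Point(1.0, Array(0.5, 1.5, 2.5)),
      Point(0.0, Array(2.0, 0.1, 0.3)))
    val algo = new SketchAlgorithm
    // Calls the two-argument overload directly, the path that used to fail.
    val weights = algo.run(data, Array(0.0, 0.0, 0.0))
    println(s"numFeatures = ${algo.featureCount}, weights = ${weights.mkString(",")}")
  }
}
```

Repeating the guard in both overloads keeps each entry point safe on its own, at the cost of a small duplication.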

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #5167 from yanboliang/spark-6496 and squashes the following commits:

8131c48 [Yanbo Liang] LogisticRegressionWithLBFGS.run(input, initialWeights) should initialize numFeatures
parent 64262ed9
@@ -211,6 +211,10 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
    */
   def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
+    if (numFeatures < 0) {
+      numFeatures = input.map(_.features.size).first()
+    }
     if (input.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data is not directly cached, which may hurt performance if its"
         + " parent RDDs are also uncached.")
...
@@ -425,6 +425,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
     val model = lr.run(testRDD)
+    val numFeatures = testRDD.map(_.features.size).first()
+    val initialWeights = Vectors.dense(new Array[Double]((numFeatures + 1) * 2))
+    val model2 = lr.run(testRDD, initialWeights)
+    LogisticRegressionSuite.checkModelsEqual(model, model2)
     /**
      * The following is the instruction to reproduce the model using R's glmnet package.
      *
...