Skip to content
Snippets Groups Projects
Commit 59a49db5 authored by DB Tsai's avatar DB Tsai Committed by Xiangrui Meng
Browse files

[SPARK-4887][MLlib] Fix a bad unittest in LogisticRegressionSuite

The original test doesn't make sense since if you step in, the lossSum is already NaN,
and the coefficients are diverging. That's because the step size is too large for SGD,
so it doesn't work.

The correct behavior is that you should get smaller coefficients than the one
without regularization. Comparing the values using 20000.0 relative error doesn't
make sense as well.

Author: DB Tsai <dbtsai@alpinenow.com>

Closes #3735 from dbtsai/mlortestfix and squashes the following commits:

b1a3c42 [DB Tsai] first commit
parent 3720057b
No related branches found
No related tags found
No related merge requests found
...@@ -178,15 +178,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M ...@@ -178,15 +178,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
// Use half as many iterations as the previous test. // Use half as many iterations as the previous test.
val lr = new LogisticRegressionWithSGD().setIntercept(true) val lr = new LogisticRegressionWithSGD().setIntercept(true)
lr.optimizer. lr.optimizer.
setStepSize(10.0). setStepSize(1.0).
setNumIterations(10). setNumIterations(10).
setRegParam(1.0) setRegParam(1.0)
val model = lr.run(testRDD, initialWeights) val model = lr.run(testRDD, initialWeights)
// Test the weights // Test the weights
assert(model.weights(0) ~== -430000.0 relTol 20000.0) // With regularization, the resulting weights will be smaller.
assert(model.intercept ~== 370000.0 relTol 20000.0) assert(model.weights(0) ~== -0.14 relTol 0.02)
assert(model.intercept ~== 0.25 relTol 0.02)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2) val validationRDD = sc.parallelize(validationData, 2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment