diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 4e812994405b329ee7984ebc0b1b358ca86e6fa9..94b0e00f37267a064b23b68aa28d86bca5a3c003 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -178,15 +178,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
     // Use half as many iterations as the previous test.
     val lr = new LogisticRegressionWithSGD().setIntercept(true)
     lr.optimizer.
-      setStepSize(10.0).
+      setStepSize(1.0).
       setNumIterations(10).
       setRegParam(1.0)
 
     val model = lr.run(testRDD, initialWeights)
 
     // Test the weights
-    assert(model.weights(0) ~== -430000.0 relTol 20000.0)
-    assert(model.intercept ~== 370000.0 relTol 20000.0)
+    // With regularization, the resulting weights will be smaller.
+    assert(model.weights(0) ~== -0.14 relTol 0.02)
+    assert(model.intercept ~== 0.25 relTol 0.02)
 
     val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
     val validationRDD = sc.parallelize(validationData, 2)
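
The much smaller expected weights follow from the regularization taking effect once the step size is sane. Below is a minimal sketch, not the test's actual code path: it assumes an L2 ("weight decay") SGD step of the form w' = w * (1 - stepSize_t * regParam) - stepSize_t * gradient with a decaying step size; the object name L2ShrinkageSketch and all numeric values are hypothetical, chosen only to illustrate why a regParam of 1.0 keeps the fitted weight near zero instead of letting it blow up.

// Sketch only (hypothetical values): an L2-regularized SGD step shrinks the
// weight toward zero on every iteration, so the result stays small even when
// the starting weight is large.
object L2ShrinkageSketch {
  def main(args: Array[String]): Unit = {
    var w = -5.0                         // hypothetical starting weight, far from zero
    val gradient = 0.1                   // held fixed here purely for illustration
    val regParam = 1.0                   // same regParam as the test sets
    for (iter <- 1 to 10) {
      val stepSize = 1.0 / math.sqrt(iter)                     // decaying 1/sqrt(t) schedule
      w = w * (1.0 - stepSize * regParam) - stepSize * gradient // shrink, then take a gradient step
    }
    // With these toy numbers the weight settles near -gradient / regParam = -0.1,
    // orders of magnitude smaller than the unregularized run's -430000.0.
    println(f"weight after 10 regularized steps: $w%.4f")
  }
}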