diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 770a2571bb9c2f662b61c0675decf89f725d9c53..f137c8cb418944283705d577b4571a425a8a5335 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -505,7 +505,11 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def initialize(y: Double, weight: Double): Double = { require(y >= 0.0, "The response variable of Poisson family " + s"should be non-negative, but got $y") - y + /* + Force Poisson mean > 0 to avoid numerical instability in IRLS. + R uses y + 0.1 for initialization. See poisson()$initialize. + */ + math.max(y, 0.1) } override def variance(mu: Double): Double = mu diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 4fab2160339c6a6acf1ac07fe7a53f77d6d7c87c..3e9e1fced8ec4d1428c46da03ee8f39bb205a0ae 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -89,11 +89,14 @@ class GeneralizedLinearRegressionSuite xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "poisson", link = "log").toDF() - datasetPoissonLogWithZero = generateGeneralizedLinearRegressionInput( - intercept = -1.5, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 100, seed, noiseLevel = 0.01, - family = "poisson", link = "log") - .map{x => LabeledPoint(if (x.label < 0.7) 0.0 else x.label, x.features)}.toDF() + datasetPoissonLogWithZero = Seq( + LabeledPoint(0.0, Vectors.dense(18, 1.0)), + LabeledPoint(1.0, Vectors.dense(12, 0.0)), + LabeledPoint(0.0, Vectors.dense(15, 0.0)), + LabeledPoint(0.0, Vectors.dense(13, 2.0)), + LabeledPoint(0.0, Vectors.dense(15, 1.0)), + LabeledPoint(1.0, Vectors.dense(16, 1.0)) + ).toDF() datasetPoissonIdentity = generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), @@ -480,12 +483,12 @@ class GeneralizedLinearRegressionSuite model <- glm(formula, family="poisson", data=data) print(as.vector(coef(model))) } - [1] 0.4272661 -0.1565423 - [1] -3.6911354 0.6214301 0.1295814 + [1] -0.0457441 -0.6833928 + [1] 1.8121235 -0.1747493 -0.5815417 */ val expected = Seq( - Vectors.dense(0.0, 0.4272661, -0.1565423), - Vectors.dense(-3.6911354, 0.6214301, 0.1295814)) + Vectors.dense(0.0, -0.0457441, -0.6833928), + Vectors.dense(1.8121235, -0.1747493, -0.5815417)) import GeneralizedLinearRegression._