Commit b8280271, authored by actuaryzhang, committed by Sean Owen

[SPARK-18701][ML] Fix Poisson GLM failure due to wrong initialization

Poisson GLM fails for many standard data sets (see the example in the test or in the JIRA). The cause is an incorrect initialization that leads to almost zero probability and weights. Specifically, the mean is initialized as the response, which can be zero. Applying the log link then yields very large negative numbers (the code guards against -Inf), which in turn lead to near-zero probability and weights in the weighted least squares step. The fix and a test are included in the commits.
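
The failure mode described above can be sketched in plain Scala, without Spark. This is only an illustration under stated assumptions: the object name `PoissonInitSketch`, the helper names, and the `epsilon` floor of `1e-16` are hypothetical stand-ins, not the Spark source. It relies on the standard IRLS result that, for a Poisson family with log link, the working weight equals the mean.

```scala
// Illustrative sketch only; this is not the Spark source.
object PoissonInitSketch {
  // Assumed floor that stands in for the "-Inf protection" mentioned above.
  val epsilon: Double = 1e-16

  // Log link eta = log(mu), with mu floored so the result stays finite.
  def link(mu: Double): Double = math.log(math.max(mu, epsilon))

  // IRLS working weight for a Poisson family with log link:
  // w = 1 / (Var(mu) * g'(mu)^2) = 1 / (mu * (1 / mu)^2) = mu.
  def workingWeight(mu: Double): Double = math.max(mu, epsilon)

  // Initialization before the fix: the mean starts at the response itself.
  def oldInit(y: Double): Double = y

  // Initialization after the fix, as in the diff: math.max(y, 0.1).
  def newInit(y: Double): Double = math.max(y, 0.1)

  def main(args: Array[String]): Unit = {
    for (y <- Seq(0.0, 1.0, 3.0)) {
      val muOld = oldInit(y)
      val muNew = newInit(y)
      println(s"y=$y old: eta=${link(muOld)} w=${workingWeight(muOld)}" +
        s" | new: eta=${link(muNew)} w=${workingWeight(muNew)}")
    }
  }
}
```

Under the assumed floor, a zero response starts the old scheme at a linear predictor near -36.8 with a weight of 1e-16, so that observation contributes almost nothing to the first weighted least squares step. The new scheme starts it near -2.3 with a weight of 0.1.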

## What changes were proposed in this pull request?
Update the initialization in Poisson GLM so that the initial mean is strictly positive (`math.max(y, 0.1)` instead of `y`).

## How was this patch tested?
Add a test in GeneralizedLinearRegressionSuite that uses a small data set whose response contains zeros.
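
As a rough way to try the scenario outside the test suite, the six-row data set from the diff can be fitted through the public `spark.ml` API. This is a sketch, not the suite's test code: the object name `PoissonWithZerosSketch` and the local `SparkSession` setup are assumptions, and it needs Spark MLlib (2.x or later) on the classpath.

```scala
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{GeneralizedLinearRegression, GeneralizedLinearRegressionModel}
import org.apache.spark.sql.SparkSession

// Hypothetical driver object; only the data rows come from the diff.
object PoissonWithZerosSketch {
  def fit(spark: SparkSession): GeneralizedLinearRegressionModel = {
    import spark.implicits._
    // Same six points as datasetPoissonLogWithZero in the updated test.
    val data = Seq(
      LabeledPoint(0.0, Vectors.dense(18.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(12.0, 0.0)),
      LabeledPoint(0.0, Vectors.dense(15.0, 0.0)),
      LabeledPoint(0.0, Vectors.dense(13.0, 2.0)),
      LabeledPoint(0.0, Vectors.dense(15.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(16.0, 1.0))
    ).toDF()

    // Poisson family with log link; default column names "label" and "features".
    new GeneralizedLinearRegression()
      .setFamily("poisson")
      .setLink("log")
      .setFitIntercept(true)
      .fit(data)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("PoissonWithZerosSketch")
      .getOrCreate()
    val model = fit(spark)
    println(s"intercept=${model.intercept} coefficients=${model.coefficients}")
    spark.stop()
  }
}
```

The R reference values recorded in the updated test for the fit with an intercept are 1.8121235, -0.1747493, -0.5815417, which appear to correspond to this data set; a fit with the patched initialization would be expected to land close to them.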

srowen sethah yanboliang HyukjinKwon mengxr

Author: actuaryzhang <actuaryzhang10@gmail.com>

Closes #16131 from actuaryzhang/master.
parent 90b59d1b
@@ -505,7 +505,11 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
     override def initialize(y: Double, weight: Double): Double = {
       require(y >= 0.0, "The response variable of Poisson family " +
         s"should be non-negative, but got $y")
-      y
+      /*
+      Force Poisson mean > 0 to avoid numerical instability in IRLS.
+      R uses y + 0.1 for initialization. See poisson()$initialize.
+      */
+      math.max(y, 0.1)
     }
     override def variance(mu: Double): Double = mu
......
@@ -89,11 +89,14 @@ class GeneralizedLinearRegressionSuite
       xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
       family = "poisson", link = "log").toDF()
-    datasetPoissonLogWithZero = generateGeneralizedLinearRegressionInput(
-      intercept = -1.5, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
-      xVariance = Array(0.7, 1.2), nPoints = 100, seed, noiseLevel = 0.01,
-      family = "poisson", link = "log")
-      .map{x => LabeledPoint(if (x.label < 0.7) 0.0 else x.label, x.features)}.toDF()
+    datasetPoissonLogWithZero = Seq(
+      LabeledPoint(0.0, Vectors.dense(18, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(12, 0.0)),
+      LabeledPoint(0.0, Vectors.dense(15, 0.0)),
+      LabeledPoint(0.0, Vectors.dense(13, 2.0)),
+      LabeledPoint(0.0, Vectors.dense(15, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(16, 1.0))
+    ).toDF()
     datasetPoissonIdentity = generateGeneralizedLinearRegressionInput(
       intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
@@ -480,12 +483,12 @@ class GeneralizedLinearRegressionSuite
          model <- glm(formula, family="poisson", data=data)
          print(as.vector(coef(model)))
        }
-      [1] 0.4272661 -0.1565423
-      [1] -3.6911354 0.6214301 0.1295814
+      [1] -0.0457441 -0.6833928
+      [1] 1.8121235 -0.1747493 -0.5815417
      */
     val expected = Seq(
-      Vectors.dense(0.0, 0.4272661, -0.1565423),
-      Vectors.dense(-3.6911354, 0.6214301, 0.1295814))
+      Vectors.dense(0.0, -0.0457441, -0.6833928),
+      Vectors.dense(1.8121235, -0.1747493, -0.5815417))
     import GeneralizedLinearRegression._
......