Skip to content
Snippets Groups Projects
Unverified Commit ae6cddb7 authored by actuaryzhang's avatar actuaryzhang Committed by Sean Owen
Browse files

[SPARK-18166][MLLIB] Fix Poisson GLM bug due to wrong requirement of response values

## What changes were proposed in this pull request?

The current implementation of Poisson GLM seems to allow only positive values. This is incorrect since the support of Poisson includes the origin. The bug is easily fixed by changing the test of the Poisson variable from  'require(y **>** 0.0' to  'require(y **>=** 0.0'.

mengxr  srowen

Author: actuaryzhang <actuaryzhang10@gmail.com>
Author: actuaryzhang <actuaryzhang@uber.com>

Closes #15683 from actuaryzhang/master.
parent f95b124c
No related branches found
No related tags found
No related merge requests found
......@@ -501,8 +501,8 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
val defaultLink: Link = Log
override def initialize(y: Double, weight: Double): Double = {
require(y > 0.0, "The response variable of Poisson family " +
s"should be positive, but got $y")
require(y >= 0.0, "The response variable of Poisson family " +
s"should be non-negative, but got $y")
y
}
......
......@@ -44,6 +44,7 @@ class GeneralizedLinearRegressionSuite
@transient var datasetGaussianInverse: DataFrame = _
@transient var datasetBinomial: DataFrame = _
@transient var datasetPoissonLog: DataFrame = _
@transient var datasetPoissonLogWithZero: DataFrame = _
@transient var datasetPoissonIdentity: DataFrame = _
@transient var datasetPoissonSqrt: DataFrame = _
@transient var datasetGammaInverse: DataFrame = _
......@@ -88,6 +89,12 @@ class GeneralizedLinearRegressionSuite
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
family = "poisson", link = "log").toDF()
datasetPoissonLogWithZero = generateGeneralizedLinearRegressionInput(
intercept = -1.5, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
xVariance = Array(0.7, 1.2), nPoints = 100, seed, noiseLevel = 0.01,
family = "poisson", link = "log")
.map{x => LabeledPoint(if (x.label < 0.7) 0.0 else x.label, x.features)}.toDF()
datasetPoissonIdentity = generateGeneralizedLinearRegressionInput(
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
......@@ -139,6 +146,10 @@ class GeneralizedLinearRegressionSuite
label + "," + features.toArray.mkString(",")
}.repartition(1).saveAsTextFile(
"target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonLog")
datasetPoissonLogWithZero.rdd.map { case Row(label: Double, features: Vector) =>
label + "," + features.toArray.mkString(",")
}.repartition(1).saveAsTextFile(
"target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonLogWithZero")
datasetPoissonIdentity.rdd.map { case Row(label: Double, features: Vector) =>
label + "," + features.toArray.mkString(",")
}.repartition(1).saveAsTextFile(
......@@ -456,6 +467,40 @@ class GeneralizedLinearRegressionSuite
}
}
test("generalized linear regression: poisson family against glm (with zero values)") {
/*
R code:
f1 <- data$V1 ~ data$V2 + data$V3 - 1
f2 <- data$V1 ~ data$V2 + data$V3
data <- read.csv("path", header=FALSE)
for (formula in c(f1, f2)) {
model <- glm(formula, family="poisson", data=data)
print(as.vector(coef(model)))
}
[1] 0.4272661 -0.1565423
[1] -3.6911354 0.6214301 0.1295814
*/
val expected = Seq(
Vectors.dense(0.0, 0.4272661, -0.1565423),
Vectors.dense(-3.6911354, 0.6214301, 0.1295814))
import GeneralizedLinearRegression._
var idx = 0
val link = "log"
val dataset = datasetPoissonLogWithZero
for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link)
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " +
s"$link link and fitIntercept = $fitIntercept (with zero values).")
idx += 1
}
}
test("generalized linear regression: gamma family against glm") {
/*
R code:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment