Commit 1e07a719 authored by actuaryzhang, committed by Yanbo Liang

[SPARK-19155][ML] Make family case insensitive in GLM


## What changes were proposed in this pull request?
This is a supplement to PR #16516, which did not make the value returned by `getFamily` case insensitive. The current tests of Poisson/binomial GLMs with weights fail when the family is specified as 'Poisson' or 'Binomial', because the calculations of `dispersion` and `pValues` compare the family retrieved from `getFamily` directly against the family names:
```
model.getFamily == Binomial.name || model.getFamily == Poisson.name
```
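To illustrate the failure mode and the fix, here is a minimal standalone sketch that does not depend on Spark. The object `FamilyCheck`, the helper `hasUnitDispersion`, and the two constants are hypothetical names introduced for illustration; the constants stand in for `Binomial.name` and `Poisson.name`, which are assumed here to be lowercase strings. The patch itself calls `toLowerCase` with no arguments; passing `Locale.ROOT` is an addition in this sketch to keep the result independent of the JVM default locale.

```scala
import java.util.Locale

object FamilyCheck {
  // Hypothetical stand-ins for Binomial.name and Poisson.name (assumed lowercase).
  val BinomialName = "binomial"
  val PoissonName = "poisson"

  // Exact comparison, as before the patch: "Binomial" does not match "binomial".
  def hasUnitDispersionStrict(family: String): Boolean =
    family == BinomialName || family == PoissonName

  // Normalized comparison, in the spirit of the patch: the user-supplied
  // family is lowercased before being compared with the canonical names.
  def hasUnitDispersion(family: String): Boolean = {
    val normalized = family.toLowerCase(Locale.ROOT)
    normalized == BinomialName || normalized == PoissonName
  }
}
```

With this sketch, `FamilyCheck.hasUnitDispersionStrict("Poisson")` is false while `FamilyCheck.hasUnitDispersion("Poisson")` is true, which mirrors why the weighted tests failed before the change.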

## How was this patch tested?
Updated the existing weighted binomial and Poisson tests to set the family as 'Binomial' and 'Poisson'.

yanboliang felixcheung imatiach-msft

Author: actuaryzhang <actuaryzhang10@gmail.com>

Closes #16675 from actuaryzhang/family.

(cherry picked from commit f067acef)
Signed-off-by: Yanbo Liang <ybliang8@gmail.com>
parent 8daf10e3
@@ -1027,7 +1027,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
    */
   @Since("2.0.0")
   lazy val dispersion: Double = if (
-    model.getFamily == Binomial.name || model.getFamily == Poisson.name) {
+    model.getFamily.toLowerCase == Binomial.name ||
+      model.getFamily.toLowerCase == Poisson.name) {
     1.0
   } else {
     val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0)
@@ -1130,7 +1131,8 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] (
   @Since("2.0.0")
   lazy val pValues: Array[Double] = {
     if (isNormalSolver) {
-      if (model.getFamily == Binomial.name || model.getFamily == Poisson.name) {
+      if (model.getFamily.toLowerCase == Binomial.name ||
+        model.getFamily.toLowerCase == Poisson.name) {
         tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) }
       } else {
         tValues.map { x =>
......
@@ -757,7 +757,7 @@ class GeneralizedLinearRegressionSuite
        0.5554219 -0.4034267 0.6567520 -0.2611382
      */
     val trainer = new GeneralizedLinearRegression()
-      .setFamily("binomial")
+      .setFamily("Binomial")
       .setWeightCol("weight")
       .setFitIntercept(false)
@@ -874,7 +874,7 @@ class GeneralizedLinearRegressionSuite
        -0.4378554 0.2189277 0.1459518 -0.1094638
      */
     val trainer = new GeneralizedLinearRegression()
-      .setFamily("poisson")
+      .setFamily("Poisson")
       .setWeightCol("weight")
       .setFitIntercept(true)
......