From 1022049c78e55914c54dff6d5206ad56dba7eef4 Mon Sep 17 00:00:00 2001
From: Felix Cheung <felixcheung_m@hotmail.com>
Date: Tue, 10 Jan 2017 21:22:16 -0800
Subject: [PATCH] [SPARK-19133][SPARKR][ML][BACKPORT-2.1] fix glm for Gamma,
 clarify glm family supported

## What changes were proposed in this pull request?

backporting to 2.1, 2.0 and 1.6

## How was this patch tested?

unit tests

Author: Felix Cheung <felixcheung_m@hotmail.com>

Closes #16532 from felixcheung/rgammabackport.
---
 R/pkg/R/mllib.R                        | 7 ++++++-
 R/pkg/inst/tests/testthat/test_mllib.R | 8 ++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index d736bbb5e9..1a254ad49b 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -184,6 +184,8 @@ predict_internal <- function(object, newData) {
 #'               This can be a character string naming a family function, a family function or
 #'               the result of a call to a family function. Refer R family at
 #'               \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
+#'               Currently these families are supported: \code{binomial}, \code{gaussian},
+#'               \code{Gamma}, and \code{poisson}.
 #' @param tol positive convergence tolerance of iterations.
 #' @param maxIter integer giving the maximal number of IRLS iterations.
 #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
@@ -236,8 +238,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
               weightCol <- ""
             }
 
+            # For known families, Gamma is upper-cased
             jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
-                                "fit", formula, data@sdf, family$family, family$link,
+                                "fit", formula, data@sdf, tolower(family$family), family$link,
                                 tol, as.integer(maxIter), as.character(weightCol), regParam)
             new("GeneralizedLinearRegressionModel", jobj = jobj)
           })
@@ -252,6 +255,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
 #'               This can be a character string naming a family function, a family function or
 #'               the result of a call to a family function. Refer R family at
 #'               \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
+#'               Currently these families are supported: \code{binomial}, \code{gaussian},
+#'               \code{Gamma}, and \code{poisson}.
 #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
 #'                  weights as 1.0.
 #' @param epsilon positive convergence tolerance of iterations.
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 40c0446740..1f2fae9c81 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -74,6 +74,14 @@ test_that("spark.glm and predict", {
   data = iris, family = poisson(link = identity)), iris))
   expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
 
+  # Gamma family
+  x <- runif(100, -1, 1)
+  y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10)
+  df <- as.DataFrame(as.data.frame(list(x = x, y = y)))
+  model <- glm(y ~ x, family = Gamma, df)
+  out <- capture.output(print(summary(model)))
+  expect_true(any(grepl("Dispersion parameter for gamma family", out)))
+
   # Test stats::predict is working
   x <- rnorm(15)
   y <- x + rnorm(15)
-- 
GitLab