Skip to content
Snippets Groups Projects
Commit 9bc3507e authored by Felix Cheung, committed by Felix Cheung
Browse files

[SPARK-19133][SPARKR][ML] fix glm for Gamma, clarify glm family supported

## What changes were proposed in this pull request?

R's `family` documentation lists more families than Spark supports, so the docs now state which families are supported.

## How was this patch tested?

manual

Author: Felix Cheung <felixcheung_m@hotmail.com>

Closes #16511 from felixcheung/rdocglmfamily.
parent d5b1dc93
No related branches found
No related tags found
No related merge requests found
......@@ -52,6 +52,8 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer to R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#' Currently these families are supported: \code{binomial}, \code{gaussian},
#' \code{Gamma}, and \code{poisson}.
#' @param tol positive convergence tolerance of iterations.
#' @param maxIter integer giving the maximal number of IRLS iterations.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
......@@ -104,8 +106,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
weightCol <- ""
}
# For known families, Gamma is upper-cased
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, family$family, family$link,
"fit", formula, data@sdf, tolower(family$family), family$link,
tol, as.integer(maxIter), as.character(weightCol), regParam)
new("GeneralizedLinearRegressionModel", jobj = jobj)
})
......@@ -120,6 +123,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer to R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#' Currently these families are supported: \code{binomial}, \code{gaussian},
#' \code{Gamma}, and \code{poisson}.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#' weights as 1.0.
#' @param epsilon positive convergence tolerance of iterations.
......
......@@ -423,7 +423,7 @@ sparkR.session <- function(
#' sparkR.session()
#' url <- sparkR.uiWebUrl()
#' }
#' @note sparkR.uiWebUrl since 2.2.0
#' @note sparkR.uiWebUrl since 2.1.1
sparkR.uiWebUrl <- function() {
sc <- sparkR.callJMethod(getSparkContext(), "sc")
u <- callJMethod(sc, "uiWebUrl")
......
......@@ -61,14 +61,22 @@ test_that("spark.glm and predict", {
# poisson family
model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species,
family = poisson(link = identity))
family = poisson(link = identity))
prediction <- predict(model, training)
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
vals <- collect(select(prediction, "prediction"))
rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species,
data = iris, family = poisson(link = identity)), iris))
data = iris, family = poisson(link = identity)), iris))
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
# Gamma family
x <- runif(100, -1, 1)
y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10)
df <- as.DataFrame(as.data.frame(list(x = x, y = y)))
model <- glm(y ~ x, family = Gamma, df)
out <- capture.output(print(summary(model)))
expect_true(any(grepl("Dispersion parameter for gamma family", out)))
# Test stats::predict is working
x <- rnorm(15)
y <- x + rnorm(15)
......@@ -103,11 +111,11 @@ test_that("spark.glm summary", {
df <- suppressWarnings(createDataFrame(iris))
training <- df[df$Species %in% c("versicolor", "virginica"), ]
stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width,
family = binomial(link = "logit")))
family = binomial(link = "logit")))
rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
family = binomial(link = "logit")))
family = binomial(link = "logit")))
coefs <- unlist(stats$coefficients)
rCoefs <- unlist(rStats$coefficients)
......@@ -222,7 +230,7 @@ test_that("glm and predict", {
training <- suppressWarnings(createDataFrame(iris))
# gaussian family
model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
prediction <- predict(model, training)
prediction <- predict(model, training)
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
vals <- collect(select(prediction, "prediction"))
rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
......@@ -235,7 +243,7 @@ test_that("glm and predict", {
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
vals <- collect(select(prediction, "prediction"))
rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species,
data = iris, family = poisson(link = identity)), iris))
data = iris, family = poisson(link = identity)), iris))
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
# Test stats::predict is working
......@@ -268,11 +276,11 @@ test_that("glm summary", {
df <- suppressWarnings(createDataFrame(iris))
training <- df[df$Species %in% c("versicolor", "virginica"), ]
stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
family = binomial(link = "logit")))
family = binomial(link = "logit")))
rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
family = binomial(link = "logit")))
family = binomial(link = "logit")))
coefs <- unlist(stats$coefficients)
rCoefs <- unlist(rStats$coefficients)
......@@ -409,7 +417,7 @@ test_that("spark.survreg", {
x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
expect_error(
model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData),
NA)
NA)
expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4)
}
})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment