Skip to content
Snippets Groups Projects
Commit 7a5000f3 authored by Xin Ren's avatar Xin Ren Committed by Shivaram Venkataraman
Browse files

[SPARK-17241][SPARKR][MLLIB] SparkR spark.glm should have configurable regularization parameter

https://issues.apache.org/jira/browse/SPARK-17241

## What changes were proposed in this pull request?

Spark has configurable L2 regularization parameter for generalized linear regression. It is very important to have them in SparkR so that users can run ridge regression.

## How was this patch tested?

Test manually on local laptop.

Author: Xin Ren <iamshrek@126.com>

Closes #14856 from keypointt/SPARK-17241.
parent d008638f
No related branches found
No related tags found
No related merge requests found
......@@ -138,10 +138,11 @@ predict_internal <- function(object, newData) {
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#' weights as 1.0.
#' @param tol positive convergence tolerance of iterations.
#' @param maxIter integer giving the maximal number of IRLS iterations.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#' weights as 1.0.
#' @param regParam regularization parameter for L2 regularization.
#' @param ... additional arguments passed to the method.
#' @aliases spark.glm,SparkDataFrame,formula-method
#' @return \code{spark.glm} returns a fitted generalized linear model
......@@ -171,7 +172,8 @@ predict_internal <- function(object, newData) {
#' @note spark.glm since 2.0.0
#' @seealso \link{glm}, \link{read.ml}
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL) {
function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL,
regParam = 0.0) {
if (is.character(family)) {
family <- get(family, mode = "function", envir = parent.frame())
}
......@@ -190,7 +192,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, family$family, family$link,
tol, as.integer(maxIter), as.character(weightCol))
tol, as.integer(maxIter), as.character(weightCol), regParam)
new("GeneralizedLinearRegressionModel", jobj = jobj)
})
......
......@@ -148,6 +148,12 @@ test_that("spark.glm summary", {
baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
baseSummary <- summary(baseModel)
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
# Test spark.glm works with regularization parameter
data <- as.data.frame(cbind(a1, a2, b))
df <- suppressWarnings(createDataFrame(data))
regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0))
expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result
})
test_that("spark.glm save/load", {
......
......@@ -69,7 +69,8 @@ private[r] object GeneralizedLinearRegressionWrapper
link: String,
tol: Double,
maxIter: Int,
weightCol: String): GeneralizedLinearRegressionWrapper = {
weightCol: String,
regParam: Double): GeneralizedLinearRegressionWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
val rFormulaModel = rFormula.fit(data)
......@@ -86,6 +87,7 @@ private[r] object GeneralizedLinearRegressionWrapper
.setTol(tol)
.setMaxIter(maxIter)
.setWeightCol(weightCol)
.setRegParam(regParam)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glr))
.fit(data)
......
......@@ -1034,6 +1034,46 @@ class GeneralizedLinearRegressionSuite
.setFamily("gaussian")
.fit(datasetGaussianIdentity.as[LabeledPoint])
}
test("generalized linear regression: regularization parameter") {
/*
R code:
a1 <- c(0, 1, 2, 3)
a2 <- c(5, 2, 1, 3)
b <- c(1, 0, 1, 0)
data <- as.data.frame(cbind(a1, a2, b))
df <- suppressWarnings(createDataFrame(data))
for (regParam in c(0.0, 0.1, 1.0)) {
model <- spark.glm(df, b ~ a1 + a2, regParam = regParam)
print(as.vector(summary(model)$aic))
}
[1] 12.88188
[1] 12.92681
[1] 13.32836
*/
val dataset = spark.createDataFrame(Seq(
LabeledPoint(1, Vectors.dense(5, 0)),
LabeledPoint(0, Vectors.dense(2, 1)),
LabeledPoint(1, Vectors.dense(1, 2)),
LabeledPoint(0, Vectors.dense(3, 3))
))
val expected = Seq(12.88188, 12.92681, 13.32836)
var idx = 0
for (regParam <- Seq(0.0, 0.1, 1.0)) {
val trainer = new GeneralizedLinearRegression()
.setRegParam(regParam)
.setLabelCol("label")
.setFeaturesCol("features")
val model = trainer.fit(dataset)
val actual = model.summary.aic
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with regParam = $regParam.")
idx += 1
}
}
}
object GeneralizedLinearRegressionSuite {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment