Skip to content
Snippets Groups Projects
Commit fba9e954 authored by Yanbo Liang's avatar Yanbo Liang Committed by Xiangrui Meng
Browse files

[SPARK-11369][ML][R] SparkR glm should support setting standardize

SparkR glm currently support :
```formula, family = c(“gaussian”, “binomial”), data, lambda = 0, alpha = 0```
We should also support setting standardize which has been defined at [design documentation](https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit)

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9331 from yanboliang/spark-11369.
parent fd9e345c
No related branches found
No related tags found
No related merge requests found
...@@ -46,11 +46,11 @@ setClass("PipelineModel", representation(model = "jobj")) ...@@ -46,11 +46,11 @@ setClass("PipelineModel", representation(model = "jobj"))
#'} #'}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0,
solver = "auto") { standardize = TRUE, solver = "auto") {
family <- match.arg(family) family <- match.arg(family)
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"fitRModelFormula", deparse(formula), data@sdf, family, lambda, "fitRModelFormula", deparse(formula), data@sdf, family, lambda,
alpha, solver) alpha, standardize, solver)
return(new("PipelineModel", model = model)) return(new("PipelineModel", model = model))
}) })
......
...@@ -31,6 +31,7 @@ private[r] object SparkRWrappers { ...@@ -31,6 +31,7 @@ private[r] object SparkRWrappers {
family: String, family: String,
lambda: Double, lambda: Double,
alpha: Double, alpha: Double,
standardize: Boolean,
solver: String): PipelineModel = { solver: String): PipelineModel = {
val formula = new RFormula().setFormula(value) val formula = new RFormula().setFormula(value)
val estimator = family match { val estimator = family match {
...@@ -38,11 +39,13 @@ private[r] object SparkRWrappers { ...@@ -38,11 +39,13 @@ private[r] object SparkRWrappers {
.setRegParam(lambda) .setRegParam(lambda)
.setElasticNetParam(alpha) .setElasticNetParam(alpha)
.setFitIntercept(formula.hasIntercept) .setFitIntercept(formula.hasIntercept)
.setStandardization(standardize)
.setSolver(solver) .setSolver(solver)
case "binomial" => new LogisticRegression() case "binomial" => new LogisticRegression()
.setRegParam(lambda) .setRegParam(lambda)
.setElasticNetParam(alpha) .setElasticNetParam(alpha)
.setFitIntercept(formula.hasIntercept) .setFitIntercept(formula.hasIntercept)
.setStandardization(standardize)
} }
val pipeline = new Pipeline().setStages(Array(formula, estimator)) val pipeline = new Pipeline().setStages(Array(formula, estimator))
pipeline.fit(df) pipeline.fit(df)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment