diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 11940d356039edc308bac87dda8b4af5e4a95101..647cbbdd825e357f33567d447bcc8be6e4eb013e 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1406,7 +1406,7 @@ setGeneric("spark.randomForest",
 
 #' @rdname spark.survreg
 #' @export
-setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") })
+setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
 
 #' @rdname spark.svmLinear
 #' @export
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index fa0d795faa10fcb55137b13ce4efce9a111b3c53..05bb95266173a2157bc785501a00b3a2447d6dd2 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -207,6 +207,9 @@ function(object, path, overwrite = FALSE) {
 #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p
 #' is the original probability of that class and t is the class's threshold.
 #' @param weightCol The weight column name.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#'                         or the number of partitions are large, this param could be adjusted to a larger size.
+#'                         This is an expert parameter. Default value should be good for most cases.
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.logit} returns a fitted logistic regression model.
 #' @rdname spark.logit
@@ -245,11 +248,13 @@ function(object, path, overwrite = FALSE) {
 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
                    tol = 1E-6, family = "auto", standardization = TRUE,
-                   thresholds = 0.5, weightCol = NULL) {
+                   thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) {
             formula <- paste(deparse(formula), collapse = "")
 
-            if (is.null(weightCol)) {
-              weightCol <- ""
+            if (!is.null(weightCol) && weightCol == "") {
+              weightCol <- NULL
+            } else if (!is.null(weightCol)) {
+              weightCol <- as.character(weightCol)
             }
 
             jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
@@ -257,7 +262,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
                                 as.numeric(elasticNetParam), as.integer(maxIter),
                                 as.numeric(tol), as.character(family),
                                 as.logical(standardization), as.array(thresholds),
-                                as.character(weightCol))
+                                weightCol, as.integer(aggregationDepth))
             new("LogisticRegressionModel", jobj = jobj)
           })
 
diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R
index 96ee220bc41137e06a42a8e388076bc5b6edae0f..ac0578c4ab2599a530086e15f001e12441951488 100644
--- a/R/pkg/R/mllib_regression.R
+++ b/R/pkg/R/mllib_regression.R
@@ -102,14 +102,16 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
             }
 
             formula <- paste(deparse(formula), collapse = "")
-            if (is.null(weightCol)) {
-              weightCol <- ""
+            if (!is.null(weightCol) && weightCol == "") {
+              weightCol <- NULL
+            } else if (!is.null(weightCol)) {
+              weightCol <- as.character(weightCol)
             }
 
             # For known families, Gamma is upper-cased
             jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
                                 "fit", formula, data@sdf, tolower(family$family), family$link,
-                                tol, as.integer(maxIter), as.character(weightCol), regParam)
+                                tol, as.integer(maxIter), weightCol, regParam)
             new("GeneralizedLinearRegressionModel", jobj = jobj)
           })
 
@@ -305,13 +307,15 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"
           function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) {
             formula <- paste(deparse(formula), collapse = "")
 
-            if (is.null(weightCol)) {
-              weightCol <- ""
+            if (!is.null(weightCol) && weightCol == "") {
+              weightCol <- NULL
+            } else if (!is.null(weightCol)) {
+              weightCol <- as.character(weightCol)
             }
 
             jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit",
                                 data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
-                                as.character(weightCol))
+                                weightCol)
             new("IsotonicRegressionModel", jobj = jobj)
           })
 
@@ -372,6 +376,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
 #' @param formula a symbolic description of the model to be fitted. Currently only a few formula
 #'                operators are supported, including '~', ':', '+', and '-'.
 #'                Note that operator '.' is not supported currently.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#'                         or the number of partitions are large, this param could be adjusted to a larger size.
+#'                         This is an expert parameter. Default value should be good for most cases.
+#' @param ... additional arguments passed to the method.
 #' @return \code{spark.survreg} returns a fitted AFT survival regression model.
 #' @rdname spark.survreg
 #' @seealso survival: \url{https://cran.r-project.org/package=survival}
@@ -396,10 +404,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
 #' }
 #' @note spark.survreg since 2.0.0
 setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
-          function(data, formula) {
+          function(data, formula, aggregationDepth = 2) {
             formula <- paste(deparse(formula), collapse = "")
             jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
-                                "fit", formula, data@sdf)
+                                "fit", formula, data@sdf, as.integer(aggregationDepth))
             new("AFTSurvivalRegressionModel", jobj = jobj)
           })
 
diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R
index 620f528f2e6c8009367f28aa82887f87f5c6ec4b..459254d271a58b5de7406d0f3a7da858a727afc5 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_classification.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R
@@ -211,7 +211,15 @@ test_that("spark.logit", {
   df <- createDataFrame(data)
   model <- spark.logit(df, label ~ feature)
   prediction <- collect(select(predict(model, df), "prediction"))
-  expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0"))
+  expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))
+
+  # Test prediction with weightCol
+  weight <- c(2.0, 2.0, 2.0, 1.0, 1.0)
+  data2 <- as.data.frame(cbind(label, feature, weight))
+  df2 <- createDataFrame(data2)
+  model2 <- spark.logit(df2, label ~ feature, weightCol = "weight")
+  prediction2 <- collect(select(predict(model2, df2), "prediction"))
+  expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0"))
 })
 
 test_that("spark.mlp", {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index bd965acf56944d3708b1ff3dc6e325901ff4340e..0bf543d88894ec34b532442834ca8322cba10868 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -82,7 +82,10 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
 
   }
 
-  def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = {
+  def fit(
+      formula: String,
+      data: DataFrame,
+      aggregationDepth: Int): AFTSurvivalRegressionWrapper = {
 
     val (rewritedFormula, censorCol) = formulaRewrite(formula)
 
@@ -100,6 +103,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
       .setCensorCol(censorCol)
       .setFitIntercept(rFormula.hasIntercept)
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setAggregationDepth(aggregationDepth)
 
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, aft))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 78f401f29b004f5288c08ec51984602266aa7be6..cbd6cd1c7933c0fd2cbff999aad2f286ded8d911 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -87,9 +87,11 @@ private[r] object GeneralizedLinearRegressionWrapper
       .setFitIntercept(rFormula.hasIntercept)
       .setTol(tol)
       .setMaxIter(maxIter)
-      .setWeightCol(weightCol)
       .setRegParam(regParam)
       .setFeaturesCol(rFormula.getFeaturesCol)
+
+    if (weightCol != null) glr.setWeightCol(weightCol)
+
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, glr))
       .fit(data)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
index 48632316f39508afc43921e39bcb30339e8ebac8..d31ebb46afb9704c2796e0cc75c6ad11e15cab90 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
@@ -74,9 +74,10 @@ private[r] object IsotonicRegressionWrapper
     val isotonicRegression = new IsotonicRegression()
       .setIsotonic(isotonic)
       .setFeatureIndex(featureIndex)
-      .setWeightCol(weightCol)
       .setFeaturesCol(rFormula.getFeaturesCol)
 
+    if (weightCol != null) isotonicRegression.setWeightCol(weightCol)
+
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, isotonicRegression))
       .fit(data)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
index 645bc7247f30f5b50d07a40bc6f18f2a77e3296d..c96f99cb83434a32bd19717fc84fbe604c79c9bc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
@@ -96,7 +96,8 @@ private[r] object LogisticRegressionWrapper
       family: String,
       standardization: Boolean,
       thresholds: Array[Double],
-      weightCol: String
+      weightCol: String,
+      aggregationDepth: Int
       ): LogisticRegressionWrapper = {
 
     val rFormula = new RFormula()
@@ -119,10 +120,10 @@ private[r] object LogisticRegressionWrapper
       .setFitIntercept(fitIntercept)
       .setFamily(family)
       .setStandardization(standardization)
-      .setWeightCol(weightCol)
       .setFeaturesCol(rFormula.getFeaturesCol)
       .setLabelCol(rFormula.getLabelCol)
      .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
+      .setAggregationDepth(aggregationDepth)
 
     if (thresholds.length > 1) {
       lr.setThresholds(thresholds)
@@ -130,6 +131,8 @@ private[r] object LogisticRegressionWrapper
       lr.setThreshold(thresholds(0))
     }
 
+    if (weightCol != null) lr.setWeightCol(weightCol)
+
     val idxToStr = new IndexToString()
       .setInputCol(PREDICTED_LABEL_INDEX_COL)
       .setOutputCol(PREDICTED_LABEL_COL)
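
A minimal usage sketch of the changed spark.logit API. The label/feature values below are illustrative, not the test fixture's (which is defined outside the hunk shown above); only the weight vector mirrors the new test:

    library(SparkR)
    sparkR.session()

    # Hypothetical toy data; `weight` mirrors the weights used in the new test.
    label <- c(0.0, 0.0, 0.0, 1.0, 1.0)
    feature <- c(1.1, 1.5, 1.3, 3.2, 3.0)
    weight <- c(2.0, 2.0, 2.0, 1.0, 1.0)
    df <- createDataFrame(as.data.frame(cbind(label, feature, weight)))

    # With this patch, weightCol is forwarded to the JVM only when it names a
    # real column: both NULL (the default) and "" now mean "no weight column",
    # instead of "" being passed through as a bogus column name.
    model <- spark.logit(df, label ~ feature, weightCol = "weight", aggregationDepth = 2)
    head(predict(model, df))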
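And a sketch of spark.survreg with the new aggregationDepth argument, reusing the ovarian dataset from the survival package as the existing spark.survreg examples do:

    library(survival)

    # Surv() in the formula is rewritten into label/censor columns by
    # AFTSurvivalRegressionWrapper.formulaRewrite on the JVM side.
    ovarianDF <- createDataFrame(ovarian)

    # aggregationDepth > 2 only pays off for high-dimensional features or many
    # partitions; the default of 2 is fine for small data like this.
    aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx,
                              aggregationDepth = 3)
    summary(aftModel)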