From 55964c15a7b639f920dfe6c104ae4fdcd673705c Mon Sep 17 00:00:00 2001 From: Felix Cheung <felixcheung_m@hotmail.com> Date: Tue, 8 Nov 2016 16:00:45 -0800 Subject: [PATCH] [SPARK-18239][SPARKR] Gradient Boosted Tree for R ## What changes were proposed in this pull request? Gradient Boosted Tree in R. With a few minor improvements to RandomForest in R. Since this is relatively isolated I'd like to target this for branch-2.1 ## How was this patch tested? manual tests, unit tests Author: Felix Cheung <felixcheung_m@hotmail.com> Closes #15746 from felixcheung/rgbt. --- R/pkg/NAMESPACE | 9 +- R/pkg/R/generics.R | 4 + R/pkg/R/mllib.R | 331 +++++++++++++++--- R/pkg/inst/tests/testthat/test_mllib.R | 68 ++++ .../spark/ml/r/GBTClassificationWrapper.scala | 164 +++++++++ .../spark/ml/r/GBTRegressionWrapper.scala | 144 ++++++++ .../org/apache/spark/ml/r/RWrappers.scala | 4 + .../r/RandomForestClassificationWrapper.scala | 14 +- .../ml/r/RandomForestRegressionWrapper.scala | 14 +- python/pyspark/ml/regression.py | 10 +- 10 files changed, 696 insertions(+), 66 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9cd6269f9a..daee09de88 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -45,7 +45,8 @@ exportMethods("glm", "spark.als", "spark.kstest", "spark.logit", - "spark.randomForest") + "spark.randomForest", + "spark.gbt") # Job group lifecycle management methods export("setJobGroup", @@ -353,7 +354,9 @@ export("as.DataFrame", "read.ml", "print.summary.KSTest", "print.summary.RandomForestRegressionModel", - "print.summary.RandomForestClassificationModel") + "print.summary.RandomForestClassificationModel", + "print.summary.GBTRegressionModel", + "print.summary.GBTClassificationModel") export("structField", "structField.jobj", @@ -380,6 +383,8 @@ S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) S3method(print, summary.RandomForestRegressionModel) S3method(print, summary.RandomForestClassificationModel) +S3method(print, summary.GBTRegressionModel) +S3method(print, summary.GBTClassificationModel) S3method(structField, character) S3method(structField, jobj) S3method(structType, jobj) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0271b26a10..7653ca7bcc 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1343,6 +1343,10 @@ setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) setGeneric("spark.gaussianMixture", function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) +#' @rdname spark.gbt +#' @export +setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") }) + #' @rdname spark.glm #' @export setGeneric("spark.glm", function(data, formula, ...) 
{ standardGeneric("spark.glm") }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 7a220b8d53..1065b4b37d 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -116,6 +116,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj")) #' @note RandomForestClassificationModel since 2.1.0 setClass("RandomForestClassificationModel", representation(jobj = "jobj")) +#' S4 class that represents a GBTRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala GBTRegressionModel +#' @export +#' @note GBTRegressionModel since 2.1.0 +setClass("GBTRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a GBTClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala GBTClassificationModel +#' @export +#' @note GBTClassificationModel since 2.1.0 +setClass("GBTClassificationModel", representation(jobj = "jobj")) + #' Saves the MLlib model to the input path #' #' Saves the MLlib model to the input path. For more information, see the specific @@ -124,7 +138,8 @@ setClass("RandomForestClassificationModel", representation(jobj = "jobj")) #' @name write.ml #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, +#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, +#' @seealso \link{spark.kmeans}, #' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, #' @seealso \link{spark.randomForest}, \link{spark.survreg}, #' @seealso \link{read.ml} @@ -138,7 +153,8 @@ NULL #' @name predict #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, +#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, +#' @seealso \link{spark.kmeans}, #' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, #' @seealso \link{spark.randomForest}, \link{spark.survreg} NULL @@ -634,7 +650,7 @@ setMethod("fitted", signature(object = "KMeansModel"), # Get the summary of a k-means model #' @param object a fitted k-means model. -#' @return \code{summary} returns the model's coefficients, size and cluster. +#' @return \code{summary} returns the model's features, coefficients, k, size and cluster. #' @rdname spark.kmeans #' @export #' @note summary(KMeansModel) since 2.0.0 @@ -679,15 +695,15 @@ setMethod("predict", signature(object = "KMeansModel"), #' @param data SparkDataFrame for training #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param regParam the regularization parameter. Default is 0.0. +#' @param regParam the regularization parameter. #' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. #' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination #' of L1 and L2. Default is 0.0 which is an L2 penalty. #' @param maxIter maximum iteration number. #' @param tol convergence tolerance of iterations. -#' @param fitIntercept whether to fit an intercept term. Default is TRUE. +#' @param fitIntercept whether to fit an intercept term. #' @param family the name of family which is a description of the label distribution to be used in the model. -#' Supported options: Default is "auto". 
+#' Supported options: #' \itemize{ #' \item{"auto": Automatically select the family based on the number of classes: #' If number of classes == 1 || number of classes == 2, set to "binomial". @@ -705,11 +721,11 @@ setMethod("predict", signature(object = "KMeansModel"), #' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification to adjust the probability of #' predicting each class. Array must have length equal to the number of classes, with values > 0, #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p -#' is the original probability of that class and t is the class's threshold. Default is 0.5. +#' is the original probability of that class and t is the class's threshold. #' @param weightCol The weight column name. #' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions -#' are large, this param could be adjusted to a larger size. Default is 2. -#' @param probabilityCol column name for predicted class conditional probabilities. Default is "probability". +#' are large, this param could be adjusted to a larger size. +#' @param probabilityCol column name for predicted class conditional probabilities. #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model #' @rdname spark.logit @@ -791,8 +807,10 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), # Get the summary of an LogisticRegressionModel #' @param object an LogisticRegressionModel fitted by \code{spark.logit} -#' @return \code{summary} returns the Binary Logistic regression results of a given model as lists. Note that -#' Multinomial logistic regression summary is not available now. +#' @return \code{summary} returns the Binary Logistic regression results of a given model as list, +#' including roc, areaUnderROC, pr, fMeasureByThreshold, precisionByThreshold, +#' recallByThreshold, totalIterations, objectiveHistory. Note that Multinomial logistic +#' regression summary is not available now. #' @rdname spark.logit #' @aliases summary,LogisticRegressionModel-method #' @export @@ -1141,6 +1159,10 @@ read.ml <- function(path) { new("RandomForestRegressionModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { new("RandomForestClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) { + new("GBTRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) { + new("GBTClassificationModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } @@ -1196,13 +1218,13 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' data and \code{write.ml}/\code{read.ml} to save/load fitted models. #' #' @param data A SparkDataFrame for training -#' @param features Features column name, default "features". Either libSVM-format column or -#' character-format column is valid. -#' @param k Number of topics, default 10 -#' @param maxIter Maximum iterations, default 20 -#' @param optimizer Optimizer to train an LDA model, "online" or "em", default "online" +#' @param features Features column name. Either libSVM-format column or character-format column is +#' valid. +#' @param k Number of topics. +#' @param maxIter Maximum iterations. +#' @param optimizer Optimizer to train an LDA model, "online" or "em", default is "online". 
#' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in -#' each iteration of mini-batch gradient descent, in range (0, 1], default 0.05 +#' each iteration of mini-batch gradient descent, in range (0, 1]. #' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for #' the prior placed on topic distributions over terms, default -1 to set automatically on the #' Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only 1-size @@ -1263,7 +1285,7 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), # similarly to R's summary(). #' @param object a fitted AFT survival regression model. -#' @return \code{summary} returns a list containing the model's coefficients, +#' @return \code{summary} returns a list containing the model's features, coefficients, #' intercept and log(scale) #' @rdname spark.survreg #' @export @@ -1351,7 +1373,7 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = # Get the summary of a multivariate gaussian mixture model #' @param object a fitted gaussian mixture model. -#' @return \code{summary} returns the model's lambda, mu, sigma and posterior. +#' @return \code{summary} returns the model's lambda, mu, sigma, k, dim and posterior. #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method #' @rdname spark.gaussianMixture #' @export @@ -1644,33 +1666,38 @@ print.summary.KSTest <- function(x, ...) { #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to #' save/load fitted models. #' For more details, see -#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{Random Forest} +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{ +#' Random Forest Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{ +#' Random Forest Classification} #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', ':', '+', and '-'. #' @param type type of model, one of "regression" or "classification", to fit -#' @param maxDepth Maximum depth of the tree (>= 0). (default = 5) +#' @param maxDepth Maximum depth of the tree (>= 0). #' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing #' how to split on features at each node. More bins give higher granularity. Must be -#' >= 2 and >= number of categories in any categorical feature. (default = 32) +#' >= 2 and >= number of categories in any categorical feature. #' @param numTrees Number of trees to train (>= 1). #' @param impurity Criterion used for information gain calculation. #' For regression, must be "variance". For classification, must be one of -#' "entropy" and "gini". (default = gini) -#' @param minInstancesPerNode Minimum number of instances each child must have after split. -#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. -#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' "entropy" and "gini", default is "gini". #' @param featureSubsetStrategy The number of features to consider for splits at each tree node. #' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. 
#' @param seed integer seed for random number generation.
#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in
-#' range (0, 1]. (default = 1.0)
-#' @param probabilityCol column name for predicted class conditional probabilities, only for
-#' classification. (default = "probability")
+#' range (0, 1].
+#' @param minInstancesPerNode Minimum number of instances each child must have after a split.
+#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
+#' @param checkpointInterval Checkpoint interval (>= 1), or -1 to disable checkpointing.
#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
-#' nodes.
+#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
+#' can speed up training of deeper trees. Users can set how often the cache
+#' should be checkpointed, or disable it, by setting checkpointInterval.
+#' @param probabilityCol column name for predicted class conditional probabilities, only for
+#' classification.
#' @param ... additional arguments passed to the method.
#' @aliases spark.randomForest,SparkDataFrame,formula-method
#' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -1703,9 +1730,9 @@ print.summary.KSTest <- function(x, ...) {
setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, type = c("regression", "classification"),
maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
- minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
- probabilityCol = "probability", maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+ minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
+ maxMemoryInMB = 256, cacheNodeIds = FALSE, probabilityCol = "probability") {
type <- match.arg(type)
formula <- paste(deparse(formula), collapse = "")
if (!is.null(seed)) {
@@ -1749,7 +1776,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
#' @rdname spark.randomForest
#' @aliases predict,RandomForestRegressionModel-method
#' @export
-#' @note predict(randomForestRegressionModel) since 2.1.0
+#' @note predict(RandomForestRegressionModel) since 2.1.0
setMethod("predict", signature(object = "RandomForestRegressionModel"),
function(object, newData) {
predict_internal(object, newData)
})
@@ -1758,7 +1785,7 @@ setMethod("predict", signature(object = "RandomForestRegressionModel"),
#' @rdname spark.randomForest
#' @aliases predict,RandomForestClassificationModel-method
#' @export
-#' @note predict(randomForestClassificationModel) since 2.1.0
+#' @note predict(RandomForestClassificationModel) since 2.1.0
setMethod("predict", signature(object = "RandomForestClassificationModel"),
function(object, newData) {
predict_internal(object, newData)
})
@@ -1789,8 +1816,8 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path
write_internal(object, path, overwrite)
})
-# Get the summary of an RandomForestRegressionModel model
-summary.randomForest <- function(model) {
+# Create the summary of a tree ensemble model (e.g. 
Random Forest, GBT) +summary.treeEnsemble <- function(model) { jobj <- model@jobj formula <- callJMethod(jobj, "formula") numFeatures <- callJMethod(jobj, "numFeatures") @@ -1807,20 +1834,23 @@ summary.randomForest <- function(model) { jobj = jobj) } -#' @return \code{summary} returns the model's features as lists, depth and number of nodes -#' or number of classes. +# Get the summary of a Random Forest Regression Model + +#' @return \code{summary} returns a summary object of the fitted model, a list of components +#' including formula, number of features, list of features, feature importances, number of +#' trees, and tree weights #' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method #' @export #' @note summary(RandomForestRegressionModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestRegressionModel"), function(object) { - ans <- summary.randomForest(object) + ans <- summary.treeEnsemble(object) class(ans) <- "summary.RandomForestRegressionModel" ans }) -# Get the summary of an RandomForestClassificationModel model +# Get the summary of a Random Forest Classification Model #' @rdname spark.randomForest #' @aliases summary,RandomForestClassificationModel-method @@ -1828,13 +1858,13 @@ setMethod("summary", signature(object = "RandomForestRegressionModel"), #' @note summary(RandomForestClassificationModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestClassificationModel"), function(object) { - ans <- summary.randomForest(object) + ans <- summary.treeEnsemble(object) class(ans) <- "summary.RandomForestClassificationModel" ans }) -# Prints the summary of Random Forest Regression Model -print.summary.randomForest <- function(x) { +# Prints the summary of tree ensemble models (eg. Random Forest, GBT) +print.summary.treeEnsemble <- function(x) { jobj <- x$jobj cat("Formula: ", x$formula) cat("\nNumber of features: ", x$numFeatures) @@ -1848,13 +1878,15 @@ print.summary.randomForest <- function(x) { invisible(x) } +# Prints the summary of Random Forest Regression Model + #' @param x summary object of Random Forest regression model or classification model #' returned by \code{summary}. #' @rdname spark.randomForest #' @export #' @note print.summary.RandomForestRegressionModel since 2.1.0 print.summary.RandomForestRegressionModel <- function(x, ...) { - print.summary.randomForest(x) + print.summary.treeEnsemble(x) } # Prints the summary of Random Forest Classification Model @@ -1863,5 +1895,214 @@ print.summary.RandomForestRegressionModel <- function(x, ...) { #' @export #' @note print.summary.RandomForestClassificationModel since 2.1.0 print.summary.RandomForestClassificationModel <- function(x, ...) { - print.summary.randomForest(x) + print.summary.treeEnsemble(x) +} + +#' Gradient Boosted Tree Model for Regression and Classification +#' +#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a +#' SparkDataFrame. Users can call \code{summary} to get a summary of the fitted +#' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and +#' \code{write.ml}/\code{read.ml} to save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{ +#' GBT Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{ +#' GBT Classification} +#' +#' @param data a SparkDataFrame for training. 
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~', ':', '+', and '-'.
+#' @param type type of model, one of "regression" or "classification", to fit
+#' @param maxDepth Maximum depth of the tree (>= 0).
+#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
+#' how to split on features at each node. More bins give higher granularity. Must be
+#' >= 2 and >= number of categories in any categorical feature.
+#' @param maxIter Maximum number of iterations (>= 0).
+#' @param stepSize Step size to be used for each iteration of optimization.
+#' @param lossType Loss function which GBT tries to minimize.
+#' For classification, must be "logistic". For regression, must be one of
+#' "squared" (L2) or "absolute" (L1), default is "squared".
+#' @param seed integer seed for random number generation.
+#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in
+#' range (0, 1].
+#' @param minInstancesPerNode Minimum number of instances each child must have after a split. If a
+#' split causes the left or right child to have fewer than
+#' minInstancesPerNode, the split will be discarded as invalid. Should be
+#' >= 1.
+#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
+#' @param checkpointInterval Checkpoint interval (>= 1), or -1 to disable checkpointing.
+#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
+#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
+#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
+#' can speed up training of deeper trees. Users can set how often the cache
+#' should be checkpointed, or disable it, by setting checkpointInterval.
+#' @param ... additional arguments passed to the method.
+#' @aliases spark.gbt,SparkDataFrame,formula-method
+#' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model.
+#' @rdname spark.gbt
+#' @name spark.gbt
+#' @export
+#' @examples
+#' \dontrun{
+#' # fit a Gradient Boosted Tree Regression Model
+#' df <- createDataFrame(longley)
+#' model <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
+#'
+#' # get the summary of the model
+#' summary(model)
+#'
+#' # make predictions
+#' predictions <- predict(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#'
+#' # fit a Gradient Boosted Tree Classification Model
+#' # label must be binary - only binary classification is supported for GBT. 
+#' df <- createDataFrame(iris[iris$Species != "virginica", ])
+#' model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification")
+#'
+#' # numeric label is also supported
+#' iris2 <- iris[iris$Species != "virginica", ]
+#' iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
+#' df <- createDataFrame(iris2)
+#' model <- spark.gbt(df, NumericSpecies ~ ., type = "classification")
+#' }
+#' @note spark.gbt since 2.1.0
+setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
+ function(data, formula, type = c("regression", "classification"),
+ maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL,
+ seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0,
+ checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+ type <- match.arg(type)
+ formula <- paste(deparse(formula), collapse = "")
+ if (!is.null(seed)) {
+ seed <- as.character(as.integer(seed))
+ }
+ switch(type,
+ regression = {
+ if (is.null(lossType)) lossType <- "squared"
+ lossType <- match.arg(lossType, c("squared", "absolute"))
+ jobj <- callJStatic("org.apache.spark.ml.r.GBTRegressorWrapper",
+ "fit", data@sdf, formula, as.integer(maxDepth),
+ as.integer(maxBins), as.integer(maxIter),
+ as.numeric(stepSize), as.integer(minInstancesPerNode),
+ as.numeric(minInfoGain), as.integer(checkpointInterval),
+ lossType, seed, as.numeric(subsamplingRate),
+ as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+ new("GBTRegressionModel", jobj = jobj)
+ },
+ classification = {
+ if (is.null(lossType)) lossType <- "logistic"
+ lossType <- match.arg(lossType, "logistic")
+ jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper",
+ "fit", data@sdf, formula, as.integer(maxDepth),
+ as.integer(maxBins), as.integer(maxIter),
+ as.numeric(stepSize), as.integer(minInstancesPerNode),
+ as.numeric(minInfoGain), as.integer(checkpointInterval),
+ lossType, seed, as.numeric(subsamplingRate),
+ as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+ new("GBTClassificationModel", jobj = jobj)
+ }
+ )
+ })
+
+# Makes predictions from a Gradient Boosted Tree Regression model or Classification model
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#' "prediction"
+#' @rdname spark.gbt
+#' @aliases predict,GBTRegressionModel-method
+#' @export
+#' @note predict(GBTRegressionModel) since 2.1.0
+setMethod("predict", signature(object = "GBTRegressionModel"),
+ function(object, newData) {
+ predict_internal(object, newData)
+ })
+
+#' @rdname spark.gbt
+#' @aliases predict,GBTClassificationModel-method
+#' @export
+#' @note predict(GBTClassificationModel) since 2.1.0
+setMethod("predict", signature(object = "GBTClassificationModel"),
+ function(object, newData) {
+ predict_internal(object, newData)
+ })
+
+# Save the Gradient Boosted Tree Regression or Classification model to the input path.
+
+#' @param object A fitted Gradient Boosted Tree regression model or classification model
+#' @param path The directory where the model is saved
+#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE,
+#' which means an exception is thrown if the output path exists. 
+#' @aliases write.ml,GBTRegressionModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,GBTClassificationModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +# Get the summary of a Gradient Boosted Tree Regression Model + +#' @return \code{summary} returns a summary object of the fitted model, a list of components +#' including formula, number of features, list of features, feature importances, number of +#' trees, and tree weights +#' @rdname spark.gbt +#' @aliases summary,GBTRegressionModel-method +#' @export +#' @note summary(GBTRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "GBTRegressionModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTRegressionModel" + ans + }) + +# Get the summary of a Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @aliases summary,GBTClassificationModel-method +#' @export +#' @note summary(GBTClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "GBTClassificationModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTClassificationModel" + ans + }) + +# Prints the summary of Gradient Boosted Tree Regression Model + +#' @param x summary object of Gradient Boosted Tree regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTRegressionModel since 2.1.0 +print.summary.GBTRegressionModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +# Prints the summary of Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTClassificationModel since 2.1.0 +print.summary.GBTClassificationModel <- function(x, ...) 
{
+ print.summary.treeEnsemble(x)
}
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 5f742d9045..33e9d0d267 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -949,4 +949,72 @@ test_that("spark.randomForest Classification", {
 unlink(modelPath)
})
+test_that("spark.gbt", {
+ # regression
+ data <- suppressWarnings(createDataFrame(longley))
+ model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123)
+ predictions <- collect(predict(model, data))
+ expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
+ 63.221, 63.639, 64.989, 63.761,
+ 66.019, 67.857, 68.169, 66.513,
+ 68.655, 69.564, 69.331, 70.551),
+ tolerance = 1e-4)
+ stats <- summary(model)
+ expect_equal(stats$numTrees, 20)
+ expect_equal(stats$formula, "Employed ~ .")
+ expect_equal(stats$numFeatures, 6)
+ expect_equal(length(stats$treeWeights), 20)
+
+ modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$formula, stats2$formula)
+ expect_equal(stats$numFeatures, stats2$numFeatures)
+ expect_equal(stats$features, stats2$features)
+ expect_equal(stats$featureImportances, stats2$featureImportances)
+ expect_equal(stats$numTrees, stats2$numTrees)
+ expect_equal(stats$treeWeights, stats2$treeWeights)
+
+ unlink(modelPath)
+
+ # classification
+ # label must be binary - GBTClassifier currently only supports binary classification.
+ iris2 <- iris[iris$Species != "virginica", ]
+ data <- suppressWarnings(createDataFrame(iris2))
+ model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification")
+ stats <- summary(model)
+ expect_equal(stats$numFeatures, 2)
+ expect_equal(stats$numTrees, 20)
+ expect_error(capture.output(stats), NA)
+ expect_true(length(capture.output(stats)) > 6)
+ predictions <- collect(predict(model, data))$prediction
+ # test string prediction values
+ expect_equal(length(grep("setosa", predictions)), 50)
+ expect_equal(length(grep("versicolor", predictions)), 50)
+
+ modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$formula, stats2$formula)
+ expect_equal(stats$numFeatures, stats2$numFeatures)
+ expect_equal(stats$numTrees, stats2$numTrees)
+
+ unlink(modelPath)
+
+ iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
+ df <- suppressWarnings(createDataFrame(iris2))
+ m <- spark.gbt(df, NumericSpecies ~ ., type = "classification")
+ s <- summary(m)
+ # test numeric prediction values
+ expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
+ expect_equal(s$numFeatures, 5)
+ expect_equal(s$numTrees, 20)
+})
+
sparkR.session.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
new file mode 100644
index 0000000000..8946025032
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} +import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} +import org.apache.spark.ml.feature.{IndexToString, RFormula} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class GBTClassifierWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + import GBTClassifierWrapper._ + + private val gbtcModel: GBTClassificationModel = + pipeline.stages(1).asInstanceOf[GBTClassificationModel] + + lazy val numFeatures: Int = gbtcModel.numFeatures + lazy val featureImportances: Vector = gbtcModel.featureImportances + lazy val numTrees: Int = gbtcModel.getNumTrees + lazy val treeWeights: Array[Double] = gbtcModel.treeWeights + + def summary: String = gbtcModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(gbtcModel.getFeaturesCol) + } + + override def write: MLWriter = new + GBTClassifierWrapper.GBTClassifierWrapperWriter(this) +} + +private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] { + + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" + val PREDICTED_LABEL_COL = "prediction" + + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + maxIter: Int, + stepSize: Double, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + lossType: String, + seed: String, + subsamplingRate: Double, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): GBTClassifierWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + .setForceIndexLabel(true) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // get label names from output schema + val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) + .asInstanceOf[NominalAttribute] + val labels = labelAttr.values.get + + // assemble and fit the pipeline + val rfc = new GBTClassifier() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setMaxIter(maxIter) + .setStepSize(stepSize) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + 
.setLossType(lossType) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + .setPredictionCol(PREDICTED_LABEL_INDEX_COL) + if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) + + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + .setLabels(labels) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfc, idxToStr)) + .fit(data) + + new GBTClassifierWrapper(pipeline, formula, features) + } + + override def read: MLReader[GBTClassifierWrapper] = new GBTClassifierWrapperReader + + override def load(path: String): GBTClassifierWrapper = super.load(path) + + class GBTClassifierWrapperWriter(instance: GBTClassifierWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class GBTClassifierWrapperReader extends MLReader[GBTClassifierWrapper] { + + override def load(path: String): GBTClassifierWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new GBTClassifierWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala new file mode 100644 index 0000000000..585077588e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class GBTRegressorWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + private val gbtrModel: GBTRegressionModel = + pipeline.stages(1).asInstanceOf[GBTRegressionModel] + + lazy val numFeatures: Int = gbtrModel.numFeatures + lazy val featureImportances: Vector = gbtrModel.featureImportances + lazy val numTrees: Int = gbtrModel.getNumTrees + lazy val treeWeights: Array[Double] = gbtrModel.treeWeights + + def summary: String = gbtrModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(gbtrModel.getFeaturesCol) + } + + override def write: MLWriter = new + GBTRegressorWrapper.GBTRegressorWrapperWriter(this) +} + +private[r] object GBTRegressorWrapper extends MLReadable[GBTRegressorWrapper] { + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + maxIter: Int, + stepSize: Double, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + lossType: String, + seed: String, + subsamplingRate: Double, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): GBTRegressorWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val rfr = new GBTRegressor() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setMaxIter(maxIter) + .setStepSize(stepSize) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setLossType(lossType) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + if (seed != null && seed.length > 0) rfr.setSeed(seed.toLong) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfr)) + .fit(data) + + new GBTRegressorWrapper(pipeline, formula, features) + } + + override def read: MLReader[GBTRegressorWrapper] = new GBTRegressorWrapperReader + + override def load(path: String): GBTRegressorWrapper = super.load(path) + + class GBTRegressorWrapperWriter(instance: GBTRegressorWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class 
GBTRegressorWrapperReader extends MLReader[GBTRegressorWrapper] { + + override def load(path: String): GBTRegressorWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new GBTRegressorWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index 0e09e18027..b59fe29234 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -60,6 +60,10 @@ private[r] object RWrappers extends MLReader[Object] { RandomForestRegressorWrapper.load(path) case "org.apache.spark.ml.r.RandomForestClassifierWrapper" => RandomForestClassifierWrapper.load(path) + case "org.apache.spark.ml.r.GBTRegressorWrapper" => + GBTRegressorWrapper.load(path) + case "org.apache.spark.ml.r.GBTClassifierWrapper" => + GBTClassifierWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index b0088ddaf3..6947ba7e75 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -35,18 +35,18 @@ private[r] class RandomForestClassifierWrapper private ( val formula: String, val features: Array[String]) extends MLWritable { - private val DTModel: RandomForestClassificationModel = + private val rfcModel: RandomForestClassificationModel = pipeline.stages(1).asInstanceOf[RandomForestClassificationModel] - lazy val numFeatures: Int = DTModel.numFeatures - lazy val featureImportances: Vector = DTModel.featureImportances - lazy val numTrees: Int = DTModel.getNumTrees - lazy val treeWeights: Array[Double] = DTModel.treeWeights + lazy val numFeatures: Int = rfcModel.numFeatures + lazy val featureImportances: Vector = rfcModel.featureImportances + lazy val numTrees: Int = rfcModel.getNumTrees + lazy val treeWeights: Array[Double] = rfcModel.treeWeights - def summary: String = DTModel.toDebugString + def summary: String = rfcModel.toDebugString def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset).drop(DTModel.getFeaturesCol) + pipeline.transform(dataset).drop(rfcModel.getFeaturesCol) } override def write: MLWriter = new diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala index c8874407fa..4b9a3a731d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala @@ -35,18 +35,18 @@ private[r] class RandomForestRegressorWrapper private ( val formula: String, val features: Array[String]) extends MLWritable { - private val DTModel: RandomForestRegressionModel = + private val rfrModel: RandomForestRegressionModel = 
pipeline.stages(1).asInstanceOf[RandomForestRegressionModel] - lazy val numFeatures: Int = DTModel.numFeatures - lazy val featureImportances: Vector = DTModel.featureImportances - lazy val numTrees: Int = DTModel.getNumTrees - lazy val treeWeights: Array[Double] = DTModel.treeWeights + lazy val numFeatures: Int = rfrModel.numFeatures + lazy val featureImportances: Vector = rfrModel.featureImportances + lazy val numTrees: Int = rfrModel.getNumTrees + lazy val treeWeights: Array[Double] = rfrModel.treeWeights - def summary: String = DTModel.toDebugString + def summary: String = rfrModel.toDebugString def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset).drop(DTModel.getFeaturesCol) + pipeline.transform(dataset).drop(rfrModel.getFeaturesCol) } override def write: MLWriter = new diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 9233d2e7e1..0bc319ca4d 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -828,7 +828,7 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada @inherit_doc class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, RandomForestParams, TreeRegressorParams, HasCheckpointInterval, - JavaMLWritable, JavaMLReadable, HasVarianceCol): + JavaMLWritable, JavaMLReadable): """ `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_ learning algorithm for regression. @@ -876,13 +876,13 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, - featureSubsetStrategy="auto", varianceCol=None): + featureSubsetStrategy="auto"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \ - featureSubsetStrategy="auto", varianceCol=None) + featureSubsetStrategy="auto") """ super(RandomForestRegressor, self).__init__() self._java_obj = self._new_java_obj( @@ -900,13 +900,13 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, - featureSubsetStrategy="auto", varianceCol=None): + featureSubsetStrategy="auto"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \ - featureSubsetStrategy="auto", varianceCol=None) + featureSubsetStrategy="auto") Sets params for linear regression. """ kwargs = self.setParams._input_kwargs -- GitLab
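
Reviewer note: the sketch below walks the new spark.gbt API end to end, using only calls documented in this patch. It is a minimal sketch, not part of the change itself; it assumes a running SparkR session on a build containing this patch (branch-2.1 target), and the choice of the longley dataset and the "absolute" loss is illustrative.

library(SparkR)
sparkR.session()

# Fit a GBT regression model with the L1 ("absolute") loss instead of the
# default "squared" loss; the other parameters keep their documented defaults.
df <- createDataFrame(longley)
model <- spark.gbt(df, Employed ~ ., type = "regression",
                   lossType = "absolute", maxIter = 20, seed = 123)

# summary() returns the tree-ensemble fields introduced by this patch:
# formula, numFeatures, features, featureImportances, numTrees, treeWeights.
stats <- summary(model)
stats$numTrees             # 20: one tree per boosting iteration (maxIter)
stats$featureImportances

# Persistence round trip: write.ml() records the wrapper class in rMetadata,
# and read.ml() dispatches on it to restore a GBTRegressionModel.
path <- tempfile(pattern = "spark-gbt-example")
write.ml(model, path)
restored <- read.ml(path)
head(collect(predict(restored, df)))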
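
Two implementation details in the spark.gbt method are easy to miss, so here is a small usage sketch; it illustrates behavior implied by the R code in this patch, not additional tested functionality. First, a NULL lossType is filled in per type ("squared" for regression, "logistic" for classification) and then validated with match.arg, so a loss that does not fit the chosen type is rejected. Second, seed is coerced with as.character(as.integer(seed)) before crossing to the JVM, because the Scala wrappers accept it as a string and call seed.toLong.

df <- createDataFrame(longley)

# Equivalent calls: lossType = NULL defaults to "squared" for regression.
m1 <- spark.gbt(df, Employed ~ ., type = "regression")
m2 <- spark.gbt(df, Employed ~ ., type = "regression", lossType = "squared")

# Error: match.arg(lossType, c("squared", "absolute")) rejects "logistic"
# when type = "regression".
# m3 <- spark.gbt(df, Employed ~ ., type = "regression", lossType = "logistic")

# Non-integer seeds are truncated by as.integer() before being sent to the JVM,
# so seed = 42.7 trains with seed 42.
m4 <- spark.gbt(df, Employed ~ ., type = "regression", seed = 42.7)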