diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7a89c01fee735c1b5d0332c8acc07aec87e1062d..9cd6269f9a8f7b648f19d26e6e3f36fa1d13f7c7 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -44,7 +44,8 @@ exportMethods("glm", "spark.gaussianMixture", "spark.als", "spark.kstest", - "spark.logit") + "spark.logit", + "spark.randomForest") # Job group lifecycle management methods export("setJobGroup", @@ -350,7 +351,9 @@ export("as.DataFrame", "uncacheTable", "print.summary.GeneralizedLinearRegressionModel", "read.ml", - "print.summary.KSTest") + "print.summary.KSTest", + "print.summary.RandomForestRegressionModel", + "print.summary.RandomForestClassificationModel") export("structField", "structField.jobj", @@ -375,6 +378,8 @@ S3method(print, structField) S3method(print, structType) S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) +S3method(print, summary.RandomForestRegressionModel) +S3method(print, summary.RandomForestClassificationModel) S3method(structField, character) S3method(structField, jobj) S3method(structType, jobj) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 107e1c638be710c0813ef1e3d230d3168c41f653..0271b26a10a90a21f46825344aa7daa057a48f6b 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1310,9 +1310,11 @@ setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @export setGeneric("year", function(x) { standardGeneric("year") }) -#' @rdname spark.glm +###################### Spark.ML Methods ########################## + +#' @rdname fitted #' @export -setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) +setGeneric("fitted") #' @param x,y For \code{glm}: logical values indicating whether the response vector #' and model matrix used in the fitting process should be returned as @@ -1332,13 +1334,38 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") }) #' @export setGeneric("rbind", signature = "...") +#' @rdname spark.als +#' @export +setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) + +#' @rdname spark.gaussianMixture +#' @export +setGeneric("spark.gaussianMixture", + function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) + +#' @rdname spark.glm +#' @export +setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) + +#' @rdname spark.isoreg +#' @export +setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) + #' @rdname spark.kmeans #' @export setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark.kmeans") }) -#' @rdname fitted +#' @rdname spark.kstest #' @export -setGeneric("fitted") +setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) + +#' @rdname spark.lda +#' @export +setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) + +#' @rdname spark.logit +#' @export +setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") }) #' @rdname spark.mlp #' @export @@ -1348,13 +1375,14 @@ setGeneric("spark.mlp", function(data, ...) { standardGeneric("spark.mlp") }) #' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) -#' @rdname spark.survreg +#' @rdname spark.randomForest #' @export -setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) +setGeneric("spark.randomForest", + function(data, formula, ...) { standardGeneric("spark.randomForest") }) -#' @rdname spark.lda +#' @rdname spark.survreg #' @export -setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) +setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) #' @rdname spark.lda #' @export @@ -1364,20 +1392,6 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark #' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) -#' @rdname spark.isoreg -#' @export -setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) - -#' @rdname spark.gaussianMixture -#' @export -setGeneric("spark.gaussianMixture", - function(data, formula, ...) { - standardGeneric("spark.gaussianMixture") - }) - -#' @rdname spark.logit -#' @export -setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") }) #' @param object a fitted ML model object. #' @param path the directory where the model is saved. @@ -1385,11 +1399,3 @@ setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark. #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) - -#' @rdname spark.als -#' @export -setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) - -#' @rdname spark.kstest -#' @export -setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 629f284b79f336d9d9eace9c4c724e0334f2a16c..7a220b8d53a2fa765739760d2468eb44e231633b 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -102,6 +102,20 @@ setClass("KSTest", representation(jobj = "jobj")) #' @note LogisticRegressionModel since 2.1.0 setClass("LogisticRegressionModel", representation(jobj = "jobj")) +#' S4 class that represents a RandomForestRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel +#' @export +#' @note RandomForestRegressionModel since 2.1.0 +setClass("RandomForestRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a RandomForestClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel +#' @export +#' @note RandomForestClassificationModel since 2.1.0 +setClass("RandomForestClassificationModel", representation(jobj = "jobj")) + #' Saves the MLlib model to the input path #' #' Saves the MLlib model to the input path. For more information, see the specific @@ -112,7 +126,7 @@ setClass("LogisticRegressionModel", representation(jobj = "jobj")) #' @seealso \link{spark.glm}, \link{glm}, #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, #' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, -#' @seealso \link{spark.survreg} +#' @seealso \link{spark.randomForest}, \link{spark.survreg}, #' @seealso \link{read.ml} NULL @@ -125,7 +139,8 @@ NULL #' @export #' @seealso \link{spark.glm}, \link{glm}, #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, -#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} +#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, +#' @seealso \link{spark.randomForest}, \link{spark.survreg} NULL write_internal <- function(object, path, overwrite = FALSE) { @@ -1122,6 +1137,10 @@ read.ml <- function(path) { new("ALSModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) { new("LogisticRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestRegressorWrapper")) { + new("RandomForestRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { + new("RandomForestClassificationModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } @@ -1617,3 +1636,232 @@ print.summary.KSTest <- function(x, ...) { cat(summaryStr, "\n") invisible(x) } + +#' Random Forest Model for Regression and Classification +#' +#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on +#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Random Forest +#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{Random Forest} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). (default = 5) +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. (default = 32) +#' @param numTrees Number of trees to train (>= 1). +#' @param impurity Criterion used for information gain calculation. +#' For regression, must be "variance". For classification, must be one of +#' "entropy" and "gini". (default = gini) +#' @param minInstancesPerNode Minimum number of instances each child must have after split. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param featureSubsetStrategy The number of features to consider for splits at each tree node. +#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. +#' @param seed integer seed for random number generation. +#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in +#' range (0, 1]. (default = 1.0) +#' @param probabilityCol column name for predicted class conditional probabilities, only for +#' classification. (default = "probability") +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. +#' @param ... additional arguments passed to the method. +#' @aliases spark.randomForest,SparkDataFrame,formula-method +#' @return \code{spark.randomForest} returns a fitted Random Forest model. +#' @rdname spark.randomForest +#' @name spark.randomForest +#' @export +#' @examples +#' \dontrun{ +#' # fit a Random Forest Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Random Forest Classification Model +#' df <- createDataFrame(iris) +#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification") +#' } +#' @note spark.randomForest since 2.1.0 +setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, + minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, + featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, + probabilityCol = "probability", maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(impurity)) impurity <- "variance" + impurity <- match.arg(impurity, "variance") + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(impurity)) impurity <- "gini" + impurity <- match.arg(impurity, c("gini", "entropy")) + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), as.character(probabilityCol), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestClassificationModel", jobj = jobj) + } + ) + }) + +# Makes predictions from a Random Forest Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction" +#' @rdname spark.randomForest +#' @aliases predict,RandomForestRegressionModel-method +#' @export +#' @note predict(randomForestRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "RandomForestRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.randomForest +#' @aliases predict,RandomForestClassificationModel-method +#' @export +#' @note predict(randomForestClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "RandomForestClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Random Forest Regression or Classification model to the input path. + +#' @param object A fitted Random Forest regression model or classification model +#' @param path The directory where the model is saved +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @aliases write.ml,RandomForestRegressionModel,character-method +#' @rdname spark.randomForest +#' @export +#' @note write.ml(RandomForestRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,RandomForestClassificationModel,character-method +#' @rdname spark.randomForest +#' @export +#' @note write.ml(RandomForestClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +# Get the summary of an RandomForestRegressionModel model +summary.randomForest <- function(model) { + jobj <- model@jobj + formula <- callJMethod(jobj, "formula") + numFeatures <- callJMethod(jobj, "numFeatures") + features <- callJMethod(jobj, "features") + featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + numTrees <- callJMethod(jobj, "numTrees") + treeWeights <- callJMethod(jobj, "treeWeights") + list(formula = formula, + numFeatures = numFeatures, + features = features, + featureImportances = featureImportances, + numTrees = numTrees, + treeWeights = treeWeights, + jobj = jobj) +} + +#' @return \code{summary} returns the model's features as lists, depth and number of nodes +#' or number of classes. +#' @rdname spark.randomForest +#' @aliases summary,RandomForestRegressionModel-method +#' @export +#' @note summary(RandomForestRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "RandomForestRegressionModel"), + function(object) { + ans <- summary.randomForest(object) + class(ans) <- "summary.RandomForestRegressionModel" + ans + }) + +# Get the summary of an RandomForestClassificationModel model + +#' @rdname spark.randomForest +#' @aliases summary,RandomForestClassificationModel-method +#' @export +#' @note summary(RandomForestClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "RandomForestClassificationModel"), + function(object) { + ans <- summary.randomForest(object) + class(ans) <- "summary.RandomForestClassificationModel" + ans + }) + +# Prints the summary of Random Forest Regression Model +print.summary.randomForest <- function(x) { + jobj <- x$jobj + cat("Formula: ", x$formula) + cat("\nNumber of features: ", x$numFeatures) + cat("\nFeatures: ", unlist(x$features)) + cat("\nFeature importances: ", x$featureImportances) + cat("\nNumber of trees: ", x$numTrees) + cat("\nTree weights: ", unlist(x$treeWeights)) + + summaryStr <- callJMethod(jobj, "summary") + cat("\n", summaryStr, "\n") + invisible(x) +} + +#' @param x summary object of Random Forest regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.randomForest +#' @export +#' @note print.summary.RandomForestRegressionModel since 2.1.0 +print.summary.RandomForestRegressionModel <- function(x, ...) { + print.summary.randomForest(x) +} + +# Prints the summary of Random Forest Classification Model + +#' @rdname spark.randomForest +#' @export +#' @note print.summary.RandomForestClassificationModel since 2.1.0 +print.summary.RandomForestClassificationModel <- function(x, ...) { + print.summary.randomForest(x) +} diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 6d1fccc7c05825013e0b4d5fd1a22ed06e85fe0a..db98d0e45547e9fe62c2e9bc1fc9043c8cb73549 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -871,4 +871,72 @@ test_that("spark.kstest", { expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") }) +test_that("spark.randomForest Regression", { + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 1) + + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + + stats <- summary(model) + expect_equal(stats$numTrees, 1) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 20, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258, + 63.736, 64.296, 64.868, 64.300, + 66.709, 67.697, 67.966, 67.252, + 68.866, 69.593, 69.195, 69.658), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + + modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) +}) + +test_that("spark.randomForest Classification", { + data <- suppressWarnings(createDataFrame(iris)) + model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) +}) + sparkR.session.stop() diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index 1df3662a5822b97055fad8bf7e7935dd74a22863..0e09e18027ca70eace96bc107bce72553080dbcc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -56,6 +56,10 @@ private[r] object RWrappers extends MLReader[Object] { ALSWrapper.load(path) case "org.apache.spark.ml.r.LogisticRegressionWrapper" => LogisticRegressionWrapper.load(path) + case "org.apache.spark.ml.r.RandomForestRegressorWrapper" => + RandomForestRegressorWrapper.load(path) + case "org.apache.spark.ml.r.RandomForestClassifierWrapper" => + RandomForestClassifierWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala new file mode 100644 index 0000000000000000000000000000000000000000..b0088ddaf3b1d6bae6ef0ae082ad92b6dc51fd70 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class RandomForestClassifierWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + private val DTModel: RandomForestClassificationModel = + pipeline.stages(1).asInstanceOf[RandomForestClassificationModel] + + lazy val numFeatures: Int = DTModel.numFeatures + lazy val featureImportances: Vector = DTModel.featureImportances + lazy val numTrees: Int = DTModel.getNumTrees + lazy val treeWeights: Array[Double] = DTModel.treeWeights + + def summary: String = DTModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(DTModel.getFeaturesCol) + } + + override def write: MLWriter = new + RandomForestClassifierWrapper.RandomForestClassifierWrapperWriter(this) +} + +private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] { + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + numTrees: Int, + impurity: String, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + featureSubsetStrategy: String, + seed: String, + subsamplingRate: Double, + probabilityCol: String, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): RandomForestClassifierWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val rfc = new RandomForestClassifier() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setNumTrees(numTrees) + .setImpurity(impurity) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setFeatureSubsetStrategy(featureSubsetStrategy) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setProbabilityCol(probabilityCol) + .setFeaturesCol(rFormula.getFeaturesCol) + if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfc)) + .fit(data) + + new RandomForestClassifierWrapper(pipeline, formula, features) + } + + override def read: MLReader[RandomForestClassifierWrapper] = + new RandomForestClassifierWrapperReader + + override def load(path: String): RandomForestClassifierWrapper = super.load(path) + + class RandomForestClassifierWrapperWriter(instance: RandomForestClassifierWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class RandomForestClassifierWrapperReader extends MLReader[RandomForestClassifierWrapper] { + + override def load(path: String): RandomForestClassifierWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new RandomForestClassifierWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala new file mode 100644 index 0000000000000000000000000000000000000000..c8874407fa75e14ef2174a968c85ae3706e11cd5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class RandomForestRegressorWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + private val DTModel: RandomForestRegressionModel = + pipeline.stages(1).asInstanceOf[RandomForestRegressionModel] + + lazy val numFeatures: Int = DTModel.numFeatures + lazy val featureImportances: Vector = DTModel.featureImportances + lazy val numTrees: Int = DTModel.getNumTrees + lazy val treeWeights: Array[Double] = DTModel.treeWeights + + def summary: String = DTModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(DTModel.getFeaturesCol) + } + + override def write: MLWriter = new + RandomForestRegressorWrapper.RandomForestRegressorWrapperWriter(this) +} + +private[r] object RandomForestRegressorWrapper extends MLReadable[RandomForestRegressorWrapper] { + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + numTrees: Int, + impurity: String, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + featureSubsetStrategy: String, + seed: String, + subsamplingRate: Double, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): RandomForestRegressorWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val rfr = new RandomForestRegressor() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setNumTrees(numTrees) + .setImpurity(impurity) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setFeatureSubsetStrategy(featureSubsetStrategy) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + if (seed != null && seed.length > 0) rfr.setSeed(seed.toLong) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfr)) + .fit(data) + + new RandomForestRegressorWrapper(pipeline, formula, features) + } + + override def read: MLReader[RandomForestRegressorWrapper] = new RandomForestRegressorWrapperReader + + override def load(path: String): RandomForestRegressorWrapper = super.load(path) + + class RandomForestRegressorWrapperWriter(instance: RandomForestRegressorWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class RandomForestRegressorWrapperReader extends MLReader[RandomForestRegressorWrapper] { + + override def load(path: String): RandomForestRegressorWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new RandomForestRegressorWrapper(pipeline, formula, features) + } + } +}