diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5c074d3c0fd40cff7a623bfedc28af19525a45e6..4e3fe00a2e9bdc936e3acef78c69f3eb1e753b4c 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -63,6 +63,7 @@ exportMethods("glm", "spark.als", "spark.kstest", "spark.logit", + "spark.decisionTree", "spark.randomForest", "spark.gbt", "spark.bisectingKmeans", @@ -414,6 +415,8 @@ export("as.DataFrame", "print.summary.GeneralizedLinearRegressionModel", "read.ml", "print.summary.KSTest", + "print.summary.DecisionTreeRegressionModel", + "print.summary.DecisionTreeClassificationModel", "print.summary.RandomForestRegressionModel", "print.summary.RandomForestClassificationModel", "print.summary.GBTRegressionModel", @@ -452,6 +455,8 @@ S3method(print, structField) S3method(print, structType) S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) +S3method(print, summary.DecisionTreeRegressionModel) +S3method(print, summary.DecisionTreeClassificationModel) S3method(print, summary.RandomForestRegressionModel) S3method(print, summary.RandomForestClassificationModel) S3method(print, summary.GBTRegressionModel) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 514ca99d45cd343fcb5474214bf3d6006f095bae..5630d0c8a0df9e99c7e38718012f1cd13982cb6c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1506,6 +1506,11 @@ setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.ml #' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) +#' @rdname spark.decisionTree +#' @export +setGeneric("spark.decisionTree", + function(data, formula, ...) { standardGeneric("spark.decisionTree") }) + #' @rdname spark.randomForest #' @export setGeneric("spark.randomForest", diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 82279be6fbe775ce6cef137657b2bb75a45f3de4..2f1220a752783652085b90f61985584ff2e65ce8 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -45,6 +45,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj")) #' @note RandomForestClassificationModel since 2.1.0 setClass("RandomForestClassificationModel", representation(jobj = "jobj")) +#' S4 class that represents a DecisionTreeRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel +#' @export +#' @note DecisionTreeRegressionModel since 2.3.0 +setClass("DecisionTreeRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a DecisionTreeClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel +#' @export +#' @note DecisionTreeClassificationModel since 2.3.0 +setClass("DecisionTreeClassificationModel", representation(jobj = "jobj")) + # Create the summary of a tree ensemble model (eg. 
Random Forest, GBT)
 summary.treeEnsemble <- function(model) {
   jobj <- model@jobj
@@ -81,6 +95,36 @@ print.summary.treeEnsemble <- function(x) {
   invisible(x)
 }
 
+# Create the summary of a decision tree model
+summary.decisionTree <- function(model) {
+  jobj <- model@jobj
+  formula <- callJMethod(jobj, "formula")
+  numFeatures <- callJMethod(jobj, "numFeatures")
+  features <- callJMethod(jobj, "features")
+  featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString")
+  maxDepth <- callJMethod(jobj, "maxDepth")
+  list(formula = formula,
+       numFeatures = numFeatures,
+       features = features,
+       featureImportances = featureImportances,
+       maxDepth = maxDepth,
+       jobj = jobj)
+}
+
+# Prints the summary of decision tree models
+print.summary.decisionTree <- function(x) {
+  jobj <- x$jobj
+  cat("Formula: ", x$formula)
+  cat("\nNumber of features: ", x$numFeatures)
+  cat("\nFeatures: ", unlist(x$features))
+  cat("\nFeature importances: ", x$featureImportances)
+  cat("\nMax Depth: ", x$maxDepth)
+
+  summaryStr <- callJMethod(jobj, "summary")
+  cat("\n", summaryStr, "\n")
+  invisible(x)
+}
+
 #' Gradient Boosted Tree Model for Regression and Classification
 #'
 #' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a
@@ -499,3 +543,199 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path
           function(object, path, overwrite = FALSE) {
             write_internal(object, path, overwrite)
           })
+
+#' Decision Tree Model for Regression and Classification
+#'
+#' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on
+#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree
+#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
+#' save/load fitted models.
+#' For more details, see
+#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{
+#' Decision Tree Regression} and
+#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{
+#' Decision Tree Classification}
+#'
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', ':', '+', and '-'.
+#' @param type type of model to fit, one of "regression" or "classification".
+#' @param maxDepth Maximum depth of the tree (>= 0).
+#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
+#'                how to split on features at each node. More bins give higher granularity. Must be
+#'                >= 2 and >= number of categories in any categorical feature.
+#' @param impurity Criterion used for information gain calculation.
+#'                 For regression, must be "variance". For classification, must be one of
+#'                 "entropy" or "gini"; default is "gini".
+#' @param seed integer seed for random number generation.
+#' @param minInstancesPerNode Minimum number of instances each child must have after a split.
+#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
+#' @param checkpointInterval checkpoint interval (>= 1), or -1 to disable checkpointing.
+#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
+#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
+#'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. 
Caching
+#'                     can speed up training of deeper trees. Users can set how often the cache
+#'                     should be checkpointed, or disable it, by setting checkpointInterval.
+#' @param ... additional arguments passed to the method.
+#' @aliases spark.decisionTree,SparkDataFrame,formula-method
+#' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
+#' @rdname spark.decisionTree
+#' @name spark.decisionTree
+#' @export
+#' @examples
+#' \dontrun{
+#' # fit a Decision Tree Regression Model
+#' df <- createDataFrame(longley)
+#' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
+#'
+#' # get the summary of the model
+#' summary(model)
+#'
+#' # make predictions
+#' predictions <- predict(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#'
+#' # fit a Decision Tree Classification Model
+#' t <- as.data.frame(Titanic)
+#' df <- createDataFrame(t)
+#' model <- spark.decisionTree(df, Survived ~ Freq + Age, "classification")
+#' }
+#' @note spark.decisionTree since 2.3.0
+setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, type = c("regression", "classification"),
+                   maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL,
+                   minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
+                   maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+            type <- match.arg(type)
+            formula <- paste(deparse(formula), collapse = "")
+            if (!is.null(seed)) {
+              seed <- as.character(as.integer(seed))
+            }
+            switch(type,
+                   regression = {
+                     if (is.null(impurity)) impurity <- "variance"
+                     impurity <- match.arg(impurity, "variance")
+                     jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper",
+                                         "fit", data@sdf, formula, as.integer(maxDepth),
+                                         as.integer(maxBins), impurity,
+                                         as.integer(minInstancesPerNode), as.numeric(minInfoGain),
+                                         as.integer(checkpointInterval), seed,
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                     new("DecisionTreeRegressionModel", jobj = jobj)
+                   },
+                   classification = {
+                     if (is.null(impurity)) impurity <- "gini"
+                     impurity <- match.arg(impurity, c("gini", "entropy"))
+                     jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper",
+                                         "fit", data@sdf, formula, as.integer(maxDepth),
+                                         as.integer(maxBins), impurity,
+                                         as.integer(minInstancesPerNode), as.numeric(minInfoGain),
+                                         as.integer(checkpointInterval), seed,
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                     new("DecisionTreeClassificationModel", jobj = jobj)
+                   }
+            )
+          })
+
+# Get the summary of a Decision Tree Regression Model
+
+#' @return \code{summary} returns summary information of the fitted model, which is a list.
+#'         The list of components includes \code{formula} (formula),
+#'         \code{numFeatures} (number of features), \code{features} (list of features),
+#'         \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of the tree).
+#' @rdname spark.decisionTree
+#' @aliases summary,DecisionTreeRegressionModel-method
+#' @export
+#' @note summary(DecisionTreeRegressionModel) since 2.3.0
+setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
+          function(object) {
+            ans <- summary.decisionTree(object)
+            class(ans) <- "summary.DecisionTreeRegressionModel"
+            ans
+          })
+
+# Prints the summary of Decision Tree Regression Model
+
+#' @param x summary object of Decision Tree regression model or classification model
+#'          returned by \code{summary}. 
+#' @rdname spark.decisionTree
+#' @export
+#' @note print.summary.DecisionTreeRegressionModel since 2.3.0
+print.summary.DecisionTreeRegressionModel <- function(x, ...) {
+  print.summary.decisionTree(x)
+}
+
+# Get the summary of a Decision Tree Classification Model
+
+#' @rdname spark.decisionTree
+#' @aliases summary,DecisionTreeClassificationModel-method
+#' @export
+#' @note summary(DecisionTreeClassificationModel) since 2.3.0
+setMethod("summary", signature(object = "DecisionTreeClassificationModel"),
+          function(object) {
+            ans <- summary.decisionTree(object)
+            class(ans) <- "summary.DecisionTreeClassificationModel"
+            ans
+          })
+
+# Prints the summary of Decision Tree Classification Model
+
+#' @rdname spark.decisionTree
+#' @export
+#' @note print.summary.DecisionTreeClassificationModel since 2.3.0
+print.summary.DecisionTreeClassificationModel <- function(x, ...) {
+  print.summary.decisionTree(x)
+}
+
+# Makes predictions from a Decision Tree Regression model or Classification model
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted values in a column named
+#'         "prediction".
+#' @rdname spark.decisionTree
+#' @aliases predict,DecisionTreeRegressionModel-method
+#' @export
+#' @note predict(DecisionTreeRegressionModel) since 2.3.0
+setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+#' @rdname spark.decisionTree
+#' @aliases predict,DecisionTreeClassificationModel-method
+#' @export
+#' @note predict(DecisionTreeClassificationModel) since 2.3.0
+setMethod("predict", signature(object = "DecisionTreeClassificationModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+# Save the Decision Tree Regression or Classification model to the input path.
+
+#' @param object A fitted Decision Tree regression model or classification model.
+#' @param path The directory where the model is saved.
+#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE,
+#'        which means an exception is thrown if the output path exists. 
+#' +#' @aliases write.ml,DecisionTreeRegressionModel,character-method +#' @rdname spark.decisionTree +#' @export +#' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0 +setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,DecisionTreeClassificationModel,character-method +#' @rdname spark.decisionTree +#' @export +#' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0 +setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index 5dfef8625061bbcaecf40a35117edc3f52134880..a53c92c2c4815e26df274364fc248f468d063535 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -32,8 +32,9 @@ #' @rdname write.ml #' @name write.ml #' @export -#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture}, -#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg}, +#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, +#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, +#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, #' @seealso \link{spark.kmeans}, #' @seealso \link{spark.lda}, \link{spark.logit}, #' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, @@ -48,8 +49,9 @@ NULL #' @rdname predict #' @name predict #' @export -#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture}, -#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg}, +#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, +#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, +#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, #' @seealso \link{spark.kmeans}, #' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, #' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear} @@ -110,6 +112,10 @@ read.ml <- function(path) { new("RandomForestRegressionModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { new("RandomForestClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) { + new("DecisionTreeRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeClassifierWrapper")) { + new("DecisionTreeClassificationModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) { new("GBTRegressionModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) { diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R index 146bc2878e2632fbf4e307a21b932fbe7db46383..b283e734cec53f2da21d034f1db91468c510a91c 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R @@ -209,4 +209,90 @@ test_that("spark.randomForest", { expect_equal(summary(model)$numFeatures, 4) }) +test_that("spark.decisionTree", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16) + + predictions <- 
collect(predict(model, data))
+  expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
+                                         63.221, 63.639, 64.989, 63.761,
+                                         66.019, 67.857, 68.169, 66.513,
+                                         68.655, 69.564, 69.331, 70.551),
+               tolerance = 1e-4)
+
+  stats <- summary(model)
+  expect_equal(stats$maxDepth, 5)
+  expect_error(capture.output(stats), NA)
+  expect_true(length(capture.output(stats)) > 6)
+
+  modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  expect_equal(stats$formula, stats2$formula)
+  expect_equal(stats$numFeatures, stats2$numFeatures)
+  expect_equal(stats$features, stats2$features)
+  expect_equal(stats$featureImportances, stats2$featureImportances)
+  expect_equal(stats$maxDepth, stats2$maxDepth)
+
+  unlink(modelPath)
+
+  # classification
+  data <- suppressWarnings(createDataFrame(iris))
+  model <- spark.decisionTree(data, Species ~ Petal_Length + Petal_Width, "classification",
+                              maxDepth = 5, maxBins = 16)
+
+  stats <- summary(model)
+  expect_equal(stats$numFeatures, 2)
+  expect_equal(stats$maxDepth, 5)
+  expect_error(capture.output(stats), NA)
+  expect_true(length(capture.output(stats)) > 6)
+  # Test string prediction values
+  predictions <- collect(predict(model, data))$prediction
+  expect_equal(length(grep("setosa", predictions)), 50)
+  expect_equal(length(grep("versicolor", predictions)), 50)
+
+  modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  expect_equal(stats$formula, stats2$formula)
+  expect_equal(stats$features, stats2$features)
+  expect_equal(stats$maxDepth, stats2$maxDepth)
+
+  unlink(modelPath)
+
+  # Test numeric response variable
+  labelToIndex <- function(species) {
+    switch(as.character(species),
+           setosa = 0.0,
+           versicolor = 1.0,
+           virginica = 2.0
+    )
+  }
+  iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
+  data <- suppressWarnings(createDataFrame(iris[-5]))
+  model <- spark.decisionTree(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
+                              maxDepth = 5, maxBins = 16)
+  stats <- summary(model)
+  expect_equal(stats$numFeatures, 2)
+  expect_equal(stats$maxDepth, 5)
+
+  # Test numeric prediction values
+  predictions <- collect(predict(model, data))$prediction
+  expect_equal(length(grep("1.0", predictions)), 50)
+  expect_equal(length(grep("2.0", predictions)), 50)
+
+  # spark.decisionTree classification can work on libsvm data
+  data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
+                  source = "libsvm")
+  model <- spark.decisionTree(data, label ~ features, "classification")
+  expect_equal(summary(model)$numFeatures, 4)
+})
+
 sparkR.session.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassifierWrapper.scala
new file mode 100644
index 0000000000000000000000000000000000000000..7f59825504d8e1eae88b4cf44585ae3960d23c41
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassifierWrapper.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} +import org.apache.spark.ml.feature.{IndexToString, RFormula} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.r.RWrapperUtils._ +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class DecisionTreeClassifierWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + import DecisionTreeClassifierWrapper._ + + private val dtcModel: DecisionTreeClassificationModel = + pipeline.stages(1).asInstanceOf[DecisionTreeClassificationModel] + + lazy val numFeatures: Int = dtcModel.numFeatures + lazy val featureImportances: Vector = dtcModel.featureImportances + lazy val maxDepth: Int = dtcModel.getMaxDepth + + def summary: String = dtcModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(dtcModel.getFeaturesCol) + .drop(dtcModel.getLabelCol) + } + + override def write: MLWriter = new + DecisionTreeClassifierWrapper.DecisionTreeClassifierWrapperWriter(this) +} + +private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeClassifierWrapper] { + + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" + val PREDICTED_LABEL_COL = "prediction" + + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + impurity: String, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + seed: String, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): DecisionTreeClassifierWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + .setForceIndexLabel(true) + checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get labels and feature names from output schema + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) + + // assemble and fit the pipeline + val dtc = new DecisionTreeClassifier() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setImpurity(impurity) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) + .setPredictionCol(PREDICTED_LABEL_INDEX_COL) + if (seed != null && seed.length > 0) dtc.setSeed(seed.toLong) + + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + 
.setLabels(labels)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, dtc, idxToStr))
+      .fit(data)
+
+    new DecisionTreeClassifierWrapper(pipeline, formula, features)
+  }
+
+  override def read: MLReader[DecisionTreeClassifierWrapper] =
+    new DecisionTreeClassifierWrapperReader
+
+  override def load(path: String): DecisionTreeClassifierWrapper = super.load(path)
+
+  class DecisionTreeClassifierWrapperWriter(instance: DecisionTreeClassifierWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("formula" -> instance.formula) ~
+        ("features" -> instance.features.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class DecisionTreeClassifierWrapperReader extends MLReader[DecisionTreeClassifierWrapper] {
+
+    override def load(path: String): DecisionTreeClassifierWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val formula = (rMetadata \ "formula").extract[String]
+      val features = (rMetadata \ "features").extract[Array[String]]
+
+      new DecisionTreeClassifierWrapper(pipeline, formula, features)
+    }
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressorWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressorWrapper.scala
new file mode 100644
index 0000000000000000000000000000000000000000..de712d67e6df55028f4e1939608545d5baa1b4df
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressorWrapper.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class DecisionTreeRegressorWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + private val dtrModel: DecisionTreeRegressionModel = + pipeline.stages(1).asInstanceOf[DecisionTreeRegressionModel] + + lazy val numFeatures: Int = dtrModel.numFeatures + lazy val featureImportances: Vector = dtrModel.featureImportances + lazy val maxDepth: Int = dtrModel.getMaxDepth + + def summary: String = dtrModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(dtrModel.getFeaturesCol) + } + + override def write: MLWriter = new + DecisionTreeRegressorWrapper.DecisionTreeRegressorWrapperWriter(this) +} + +private[r] object DecisionTreeRegressorWrapper extends MLReadable[DecisionTreeRegressorWrapper] { + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + impurity: String, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + seed: String, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): DecisionTreeRegressorWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val dtr = new DecisionTreeRegressor() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setImpurity(impurity) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + if (seed != null && seed.length > 0) dtr.setSeed(seed.toLong) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, dtr)) + .fit(data) + + new DecisionTreeRegressorWrapper(pipeline, formula, features) + } + + override def read: MLReader[DecisionTreeRegressorWrapper] = new DecisionTreeRegressorWrapperReader + + override def load(path: String): DecisionTreeRegressorWrapper = super.load(path) + + class DecisionTreeRegressorWrapperWriter(instance: DecisionTreeRegressorWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class DecisionTreeRegressorWrapperReader extends MLReader[DecisionTreeRegressorWrapper] { + 
+ override def load(path: String): DecisionTreeRegressorWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new DecisionTreeRegressorWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index b30ce12bc6cc8c30d3124169010774314399e288..ba6445a7303063dfd452d9a56658c4db642b916b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -60,6 +60,10 @@ private[r] object RWrappers extends MLReader[Object] { RandomForestRegressorWrapper.load(path) case "org.apache.spark.ml.r.RandomForestClassifierWrapper" => RandomForestClassifierWrapper.load(path) + case "org.apache.spark.ml.r.DecisionTreeRegressorWrapper" => + DecisionTreeRegressorWrapper.load(path) + case "org.apache.spark.ml.r.DecisionTreeClassifierWrapper" => + DecisionTreeClassifierWrapper.load(path) case "org.apache.spark.ml.r.GBTRegressorWrapper" => GBTRegressorWrapper.load(path) case "org.apache.spark.ml.r.GBTClassifierWrapper" =>
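
For reviewers: a minimal end-to-end SparkR session exercising the new API. This is a sketch based
on the documented signature and the @examples in mllib_tree.R above; the datasets, column names,
and parameter values (impurity = "entropy", seed = 42) are illustrative, not part of the change.

    library(SparkR)
    sparkR.session()

    # Regression on the longley dataset, with a save/load round trip.
    df <- createDataFrame(longley)
    regModel <- spark.decisionTree(df, Employed ~ ., type = "regression",
                                   maxDepth = 5, maxBins = 16)
    summary(regModel)   # formula, numFeatures, features, featureImportances, maxDepth
    head(predict(regModel, df))

    modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp")
    write.ml(regModel, modelPath)
    summary(read.ml(modelPath))
    unlink(modelPath)

    # Classification on Titanic, using the documented non-default impurity and a fixed seed.
    titanic <- createDataFrame(as.data.frame(Titanic))
    clsModel <- spark.decisionTree(titanic, Survived ~ Freq + Age, "classification",
                                   impurity = "entropy", seed = 42)
    head(predict(clsModel, titanic))   # string labels in the "prediction" column

    sparkR.session.stop()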