Commit 4be33758 authored by Zheng RuiFeng's avatar Zheng RuiFeng Committed by Felix Cheung

[SPARK-15767][ML][SPARKR] Decision Tree wrapper in SparkR

## What changes were proposed in this pull request?
Support decision trees in SparkR: add a `spark.decisionTree` wrapper for both regression and classification, with `summary`, `predict`, and `write.ml`/`read.ml` support.
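Illustrative usage of the new API (mirroring the roxygen examples added in this patch; assumes an active SparkR session):

```r
df <- createDataFrame(longley)

# fit a regression tree; type = "classification" selects the classifier wrapper instead
model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5)

summary(model)            # formula, features, feature importances, max depth
head(predict(model, df))  # predictions in a "prediction" column
```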

## How was this patch tested?
Added unit tests covering regression, classification, and model save/load round trips.

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #17981 from zhengruifeng/dt_r.
parent 36309110
......@@ -63,6 +63,7 @@ exportMethods("glm",
"spark.als",
"spark.kstest",
"spark.logit",
"spark.decisionTree",
"spark.randomForest",
"spark.gbt",
"spark.bisectingKmeans",
......@@ -414,6 +415,8 @@ export("as.DataFrame",
"print.summary.GeneralizedLinearRegressionModel",
"read.ml",
"print.summary.KSTest",
"print.summary.DecisionTreeRegressionModel",
"print.summary.DecisionTreeClassificationModel",
"print.summary.RandomForestRegressionModel",
"print.summary.RandomForestClassificationModel",
"print.summary.GBTRegressionModel",
@@ -452,6 +455,8 @@ S3method(print, structField)
S3method(print, structType)
S3method(print, summary.GeneralizedLinearRegressionModel)
S3method(print, summary.KSTest)
S3method(print, summary.DecisionTreeRegressionModel)
S3method(print, summary.DecisionTreeClassificationModel)
S3method(print, summary.RandomForestRegressionModel)
S3method(print, summary.RandomForestClassificationModel)
S3method(print, summary.GBTRegressionModel)
......
......@@ -1506,6 +1506,11 @@ setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.ml
#' @export
setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") })
#' @rdname spark.decisionTree
#' @export
setGeneric("spark.decisionTree",
function(data, formula, ...) { standardGeneric("spark.decisionTree") })
#' @rdname spark.randomForest
#' @export
setGeneric("spark.randomForest",
......
......@@ -45,6 +45,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj"))
#' @note RandomForestClassificationModel since 2.1.0
setClass("RandomForestClassificationModel", representation(jobj = "jobj"))
#' S4 class that represents a DecisionTreeRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel
#' @export
#' @note DecisionTreeRegressionModel since 2.3.0
setClass("DecisionTreeRegressionModel", representation(jobj = "jobj"))
#' S4 class that represents a DecisionTreeClassificationModel
#'
#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel
#' @export
#' @note DecisionTreeClassificationModel since 2.3.0
setClass("DecisionTreeClassificationModel", representation(jobj = "jobj"))
# Create the summary of a tree ensemble model (e.g., Random Forest, GBT)
summary.treeEnsemble <- function(model) {
jobj <- model@jobj
......@@ -81,6 +95,36 @@ print.summary.treeEnsemble <- function(x) {
invisible(x)
}
# Create the summary of a decision tree model
summary.decisionTree <- function(model) {
jobj <- model@jobj
formula <- callJMethod(jobj, "formula")
numFeatures <- callJMethod(jobj, "numFeatures")
features <- callJMethod(jobj, "features")
featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString")
maxDepth <- callJMethod(jobj, "maxDepth")
list(formula = formula,
numFeatures = numFeatures,
features = features,
featureImportances = featureImportances,
maxDepth = maxDepth,
jobj = jobj)
}
# Prints the summary of decision tree models
print.summary.decisionTree <- function(x) {
jobj <- x$jobj
cat("Formula: ", x$formula)
cat("\nNumber of features: ", x$numFeatures)
cat("\nFeatures: ", unlist(x$features))
cat("\nFeature importances: ", x$featureImportances)
cat("\nMax Depth: ", x$maxDepth)
summaryStr <- callJMethod(jobj, "summary")
cat("\n", summaryStr, "\n")
invisible(x)
}
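# Illustrative sketch (not part of the patch; assumes a fitted model `model`):
# summary() tags the list returned by summary.decisionTree() with an S3 class,
# so base print() dispatches to the print.summary.* wrappers defined further below:
#   stats <- summary(model)  # list(formula, numFeatures, features, featureImportances, maxDepth)
#   class(stats)             # "summary.DecisionTreeRegressionModel"
#   stats                    # auto-printing calls print.summary.decisionTree()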
#' Gradient Boosted Tree Model for Regression and Classification
#'
#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a
......@@ -499,3 +543,199 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})
#' Decision Tree Model for Regression and Classification
#'
#' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on
#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree
#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
#' save/load fitted models.
#' For more details, see
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{
#' Decision Tree Regression} and
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{
#' Decision Tree Classification}
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', ':', '+', and '-'.
#' @param type type of model to fit, one of "regression" or "classification".
#' @param maxDepth Maximum depth of the tree (>= 0).
#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
#' how to split on features at each node. More bins give higher granularity. Must be
#' >= 2 and >= number of categories in any categorical feature.
#' @param impurity Criterion used for information gain calculation.
#' For regression, must be "variance". For classification, must be either
#' "gini" or "entropy"; the default is "gini".
#' @param seed integer seed for random number generation.
#' @param minInstancesPerNode Minimum number of instances each child must have after a split.
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
#' @param checkpointInterval Param for setting the checkpoint interval (>= 1), or -1 to
#' disable checkpointing.
#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
#' @param cacheNodeIds If FALSE, the algorithm passes trees to executors to match instances with
#' nodes. If TRUE, the algorithm caches node IDs for each instance, which can
#' speed up training of deeper trees. Users can set how often the cache is
#' checkpointed, or disable it, via checkpointInterval.
#' @param ... additional arguments passed to the method.
#' @aliases spark.decisionTree,SparkDataFrame,formula-method
#' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
#' @rdname spark.decisionTree
#' @name spark.decisionTree
#' @export
#' @examples
#' \dontrun{
#' # fit a Decision Tree Regression Model
#' df <- createDataFrame(longley)
#' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
#'
#' # get the summary of the model
#' summary(model)
#'
#' # make predictions
#' predictions <- predict(model, df)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#'
#' # fit a Decision Tree Classification Model
#' t <- as.data.frame(Titanic)
#' df <- createDataFrame(t)
#' model <- spark.decisionTree(df, Survived ~ Freq + Age, "classification")
#' }
#' @note spark.decisionTree since 2.3.0
setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, type = c("regression", "classification"),
maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL,
minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
maxMemoryInMB = 256, cacheNodeIds = FALSE) {
type <- match.arg(type)
formula <- paste(deparse(formula), collapse = "")
if (!is.null(seed)) {
seed <- as.character(as.integer(seed))
}
switch(type,
regression = {
if (is.null(impurity)) impurity <- "variance"
impurity <- match.arg(impurity, "variance")
jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper",
"fit", data@sdf, formula, as.integer(maxDepth),
as.integer(maxBins), impurity,
as.integer(minInstancesPerNode), as.numeric(minInfoGain),
as.integer(checkpointInterval), seed,
as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
new("DecisionTreeRegressionModel", jobj = jobj)
},
classification = {
if (is.null(impurity)) impurity <- "gini"
impurity <- match.arg(impurity, c("gini", "entropy"))
jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper",
"fit", data@sdf, formula, as.integer(maxDepth),
as.integer(maxBins), impurity,
as.integer(minInstancesPerNode), as.numeric(minInfoGain),
as.integer(checkpointInterval), seed,
as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
new("DecisionTreeClassificationModel", jobj = jobj)
}
)
})
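# Illustrative note on the method above (assumes a SparkDataFrame `df` with a label column `y`):
# match.arg() resolves `type`, impurity defaults per type ("variance" vs. "gini"),
# and a user-supplied seed is stringified so the JVM wrapper can parse it as a Long:
#   m1 <- spark.decisionTree(df, y ~ ., "regression")                 # impurity -> "variance"
#   m2 <- spark.decisionTree(df, y ~ ., "classification", seed = 42)  # seed sent as "42"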
# Get the summary of a Decision Tree Regression Model
#' @return \code{summary} returns summary information of the fitted model, which is a list.
#' The list of components includes \code{formula} (formula),
#' \code{numFeatures} (number of features), \code{features} (list of features),
#' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of the tree).
#' @rdname spark.decisionTree
#' @aliases summary,DecisionTreeRegressionModel-method
#' @export
#' @note summary(DecisionTreeRegressionModel) since 2.3.0
setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
function(object) {
ans <- summary.decisionTree(object)
class(ans) <- "summary.DecisionTreeRegressionModel"
ans
})
# Prints the summary of Decision Tree Regression Model
#' @param x summary object of a Decision Tree regression model or classification model
#' returned by \code{summary}.
#' @rdname spark.decisionTree
#' @export
#' @note print.summary.DecisionTreeRegressionModel since 2.3.0
print.summary.DecisionTreeRegressionModel <- function(x, ...) {
print.summary.decisionTree(x)
}
# Get the summary of a Decision Tree Classification Model
#' @rdname spark.decisionTree
#' @aliases summary,DecisionTreeClassificationModel-method
#' @export
#' @note summary(DecisionTreeClassificationModel) since 2.3.0
setMethod("summary", signature(object = "DecisionTreeClassificationModel"),
function(object) {
ans <- summary.decisionTree(object)
class(ans) <- "summary.DecisionTreeClassificationModel"
ans
})
# Prints the summary of Decision Tree Classification Model
#' @rdname spark.decisionTree
#' @export
#' @note print.summary.DecisionTreeClassificationModel since 2.3.0
print.summary.DecisionTreeClassificationModel <- function(x, ...) {
print.summary.decisionTree(x)
}
# Makes predictions from a Decision Tree Regression model or Classification model
#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
#' "prediction".
#' @rdname spark.decisionTree
#' @aliases predict,DecisionTreeRegressionModel-method
#' @export
#' @note predict(DecisionTreeRegressionModel) since 2.3.0
setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
function(object, newData) {
predict_internal(object, newData)
})
#' @rdname spark.decisionTree
#' @aliases predict,DecisionTreeClassificationModel-method
#' @export
#' @note predict(DecisionTreeClassificationModel) since 2.3.0
setMethod("predict", signature(object = "DecisionTreeClassificationModel"),
function(object, newData) {
predict_internal(object, newData)
})
# Save the Decision Tree Regression or Classification model to the input path.
#' @param object A fitted Decision Tree regression model or classification model.
#' @param path The directory where the model is saved.
#' @param overwrite Whether to overwrite the model if the output path already exists. Default is
#' FALSE, which means an exception is thrown if the output path exists.
#'
#' @aliases write.ml,DecisionTreeRegressionModel,character-method
#' @rdname spark.decisionTree
#' @export
#' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0
setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})
#' @aliases write.ml,DecisionTreeClassificationModel,character-method
#' @rdname spark.decisionTree
#' @export
#' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0
setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})
@@ -32,8 +32,9 @@
#' @rdname write.ml
#' @name write.ml
#' @export
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture},
#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree},
#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt},
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' @seealso \link{spark.kmeans},
#' @seealso \link{spark.lda}, \link{spark.logit},
#' @seealso \link{spark.mlp}, \link{spark.naiveBayes},
@@ -48,8 +49,9 @@ NULL
#' @rdname predict
#' @name predict
#' @export
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture},
#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree},
#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt},
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' @seealso \link{spark.kmeans},
#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
#' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear}
......@@ -110,6 +112,10 @@ read.ml <- function(path) {
new("RandomForestRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) {
new("RandomForestClassificationModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) {
new("DecisionTreeRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeClassifierWrapper")) {
new("DecisionTreeClassificationModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) {
new("GBTRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) {
......
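# Illustrative round trip (assumes a fitted decision tree `model` and a path `modelPath`):
# read.ml() dispatches on the saved wrapper's JVM class above, so the matching
# S4 class is restored:
#   write.ml(model, modelPath)
#   loaded <- read.ml(modelPath)
#   is(loaded, "DecisionTreeRegressionModel")  # TRUE for a regression model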
......@@ -209,4 +209,90 @@ test_that("spark.randomForest", {
expect_equal(summary(model)$numFeatures, 4)
})
test_that("spark.decisionTree", {
# regression
data <- suppressWarnings(createDataFrame(longley))
model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16)
predictions <- collect(predict(model, data))
expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
63.221, 63.639, 64.989, 63.761,
66.019, 67.857, 68.169, 66.513,
68.655, 69.564, 69.331, 70.551),
tolerance = 1e-4)
stats <- summary(model)
expect_equal(stats$maxDepth, 5)
expect_error(capture.output(stats), NA)
expect_true(length(capture.output(stats)) > 6)
modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp")
write.ml(model, modelPath)
expect_error(write.ml(model, modelPath))
write.ml(model, modelPath, overwrite = TRUE)
model2 <- read.ml(modelPath)
stats2 <- summary(model2)
expect_equal(stats$formula, stats2$formula)
expect_equal(stats$numFeatures, stats2$numFeatures)
expect_equal(stats$features, stats2$features)
expect_equal(stats$featureImportances, stats2$featureImportances)
expect_equal(stats$maxDepth, stats2$maxDepth)
unlink(modelPath)
# classification
data <- suppressWarnings(createDataFrame(iris))
model <- spark.decisionTree(data, Species ~ Petal_Length + Petal_Width, "classification",
maxDepth = 5, maxBins = 16)
stats <- summary(model)
expect_equal(stats$numFeatures, 2)
expect_equal(stats$maxDepth, 5)
expect_error(capture.output(stats), NA)
expect_true(length(capture.output(stats)) > 6)
# Test string prediction values
predictions <- collect(predict(model, data))$prediction
expect_equal(length(grep("setosa", predictions)), 50)
expect_equal(length(grep("versicolor", predictions)), 50)
modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp")
write.ml(model, modelPath)
expect_error(write.ml(model, modelPath))
write.ml(model, modelPath, overwrite = TRUE)
model2 <- read.ml(modelPath)
stats2 <- summary(model2)
# compare the fields that summary() actually returns
expect_equal(stats$formula, stats2$formula)
expect_equal(stats$numFeatures, stats2$numFeatures)
expect_equal(stats$features, stats2$features)
expect_equal(stats$featureImportances, stats2$featureImportances)
expect_equal(stats$maxDepth, stats2$maxDepth)
unlink(modelPath)
# Test numeric response variable
labelToIndex <- function(species) {
switch(as.character(species),
setosa = 0.0,
versicolor = 1.0,
virginica = 2.0
)
}
iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
data <- suppressWarnings(createDataFrame(iris[-5]))
model <- spark.decisionTree(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
maxDepth = 5, maxBins = 16)
stats <- summary(model)
expect_equal(stats$numFeatures, 2)
expect_equal(stats$maxDepth, 5)
# Test numeric prediction values
predictions <- collect(predict(model, data))$prediction
expect_equal(length(grep("1.0", predictions)), 50)
expect_equal(length(grep("2.0", predictions)), 50)
# spark.decisionTree classification can work on libsvm data
data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
source = "libsvm")
model <- spark.decisionTree(data, label ~ features, "classification")
expect_equal(summary(model)$numFeatures, 4)
})
sparkR.session.stop()
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.r
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.r.RWrapperUtils._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
private[r] class DecisionTreeClassifierWrapper private (
val pipeline: PipelineModel,
val formula: String,
val features: Array[String]) extends MLWritable {
import DecisionTreeClassifierWrapper._
private val dtcModel: DecisionTreeClassificationModel =
pipeline.stages(1).asInstanceOf[DecisionTreeClassificationModel]
lazy val numFeatures: Int = dtcModel.numFeatures
lazy val featureImportances: Vector = dtcModel.featureImportances
lazy val maxDepth: Int = dtcModel.getMaxDepth
def summary: String = dtcModel.toDebugString
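// transform() runs the full pipeline, then drops the intermediate prediction-index,
// assembled-features, and indexed-label columns so R sees only the original
// columns plus the "prediction" column.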
def transform(dataset: Dataset[_]): DataFrame = {
pipeline.transform(dataset)
.drop(PREDICTED_LABEL_INDEX_COL)
.drop(dtcModel.getFeaturesCol)
.drop(dtcModel.getLabelCol)
}
override def write: MLWriter = new
DecisionTreeClassifierWrapper.DecisionTreeClassifierWrapperWriter(this)
}
private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeClassifierWrapper] {
val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
val PREDICTED_LABEL_COL = "prediction"
def fit( // scalastyle:ignore
data: DataFrame,
formula: String,
maxDepth: Int,
maxBins: Int,
impurity: String,
minInstancesPerNode: Int,
minInfoGain: Double,
checkpointInterval: Int,
seed: String,
maxMemoryInMB: Int,
cacheNodeIds: Boolean): DecisionTreeClassifierWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
.setForceIndexLabel(true)
checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
// get labels and feature names from output schema
val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
// assemble and fit the pipeline
val dtc = new DecisionTreeClassifier()
.setMaxDepth(maxDepth)
.setMaxBins(maxBins)
.setImpurity(impurity)
.setMinInstancesPerNode(minInstancesPerNode)
.setMinInfoGain(minInfoGain)
.setCheckpointInterval(checkpointInterval)
.setMaxMemoryInMB(maxMemoryInMB)
.setCacheNodeIds(cacheNodeIds)
.setFeaturesCol(rFormula.getFeaturesCol)
.setLabelCol(rFormula.getLabelCol)
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
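// seed crosses the R/JVM boundary as a String so an unset seed can arrive as null;
// it is parsed to Long and set only when provided.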
if (seed != null && seed.length > 0) dtc.setSeed(seed.toLong)
val idxToStr = new IndexToString()
.setInputCol(PREDICTED_LABEL_INDEX_COL)
.setOutputCol(PREDICTED_LABEL_COL)
.setLabels(labels)
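// Pipeline stages: RFormula featurization -> decision tree classifier ->
// IndexToString mapping predicted indices back to the original labels.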
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, dtc, idxToStr))
.fit(data)
new DecisionTreeClassifierWrapper(pipeline, formula, features)
}
override def read: MLReader[DecisionTreeClassifierWrapper] =
new DecisionTreeClassifierWrapperReader
override def load(path: String): DecisionTreeClassifierWrapper = super.load(path)
class DecisionTreeClassifierWrapperWriter(instance: DecisionTreeClassifierWrapper)
extends MLWriter {
override protected def saveImpl(path: String): Unit = {
val rMetadataPath = new Path(path, "rMetadata").toString
val pipelinePath = new Path(path, "pipeline").toString
val rMetadata = ("class" -> instance.getClass.getName) ~
("formula" -> instance.formula) ~
("features" -> instance.features.toSeq)
val rMetadataJson: String = compact(render(rMetadata))
sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
instance.pipeline.save(pipelinePath)
}
}
class DecisionTreeClassifierWrapperReader extends MLReader[DecisionTreeClassifierWrapper] {
override def load(path: String): DecisionTreeClassifierWrapper = {
implicit val format = DefaultFormats
val rMetadataPath = new Path(path, "rMetadata").toString
val pipelinePath = new Path(path, "pipeline").toString
val pipeline = PipelineModel.load(pipelinePath)
val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
val rMetadata = parse(rMetadataStr)
val formula = (rMetadata \ "formula").extract[String]
val features = (rMetadata \ "features").extract[Array[String]]
new DecisionTreeClassifierWrapper(pipeline, formula, features)
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.r
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
private[r] class DecisionTreeRegressorWrapper private (
val pipeline: PipelineModel,
val formula: String,
val features: Array[String]) extends MLWritable {
private val dtrModel: DecisionTreeRegressionModel =
pipeline.stages(1).asInstanceOf[DecisionTreeRegressionModel]
lazy val numFeatures: Int = dtrModel.numFeatures
lazy val featureImportances: Vector = dtrModel.featureImportances
lazy val maxDepth: Int = dtrModel.getMaxDepth
def summary: String = dtrModel.toDebugString
def transform(dataset: Dataset[_]): DataFrame = {
pipeline.transform(dataset).drop(dtrModel.getFeaturesCol)
}
override def write: MLWriter = new
DecisionTreeRegressorWrapper.DecisionTreeRegressorWrapperWriter(this)
}
private[r] object DecisionTreeRegressorWrapper extends MLReadable[DecisionTreeRegressorWrapper] {
def fit( // scalastyle:ignore
data: DataFrame,
formula: String,
maxDepth: Int,
maxBins: Int,
impurity: String,
minInstancesPerNode: Int,
minInfoGain: Double,
checkpointInterval: Int,
seed: String,
maxMemoryInMB: Int,
cacheNodeIds: Boolean): DecisionTreeRegressorWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
RWrapperUtils.checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
// get feature names from output schema
val schema = rFormulaModel.transform(data).schema
val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
.attributes.get
val features = featureAttrs.map(_.name.get)
// assemble and fit the pipeline
val dtr = new DecisionTreeRegressor()
.setMaxDepth(maxDepth)
.setMaxBins(maxBins)
.setImpurity(impurity)
.setMinInstancesPerNode(minInstancesPerNode)
.setMinInfoGain(minInfoGain)
.setCheckpointInterval(checkpointInterval)
.setMaxMemoryInMB(maxMemoryInMB)
.setCacheNodeIds(cacheNodeIds)
.setFeaturesCol(rFormula.getFeaturesCol)
if (seed != null && seed.length > 0) dtr.setSeed(seed.toLong)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, dtr))
.fit(data)
new DecisionTreeRegressorWrapper(pipeline, formula, features)
}
override def read: MLReader[DecisionTreeRegressorWrapper] = new DecisionTreeRegressorWrapperReader
override def load(path: String): DecisionTreeRegressorWrapper = super.load(path)
class DecisionTreeRegressorWrapperWriter(instance: DecisionTreeRegressorWrapper)
extends MLWriter {
override protected def saveImpl(path: String): Unit = {
val rMetadataPath = new Path(path, "rMetadata").toString
val pipelinePath = new Path(path, "pipeline").toString
val rMetadata = ("class" -> instance.getClass.getName) ~
("formula" -> instance.formula) ~
("features" -> instance.features.toSeq)
val rMetadataJson: String = compact(render(rMetadata))
sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
instance.pipeline.save(pipelinePath)
}
}
class DecisionTreeRegressorWrapperReader extends MLReader[DecisionTreeRegressorWrapper] {
override def load(path: String): DecisionTreeRegressorWrapper = {
implicit val format = DefaultFormats
val rMetadataPath = new Path(path, "rMetadata").toString
val pipelinePath = new Path(path, "pipeline").toString
val pipeline = PipelineModel.load(pipelinePath)
val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
val rMetadata = parse(rMetadataStr)
val formula = (rMetadata \ "formula").extract[String]
val features = (rMetadata \ "features").extract[Array[String]]
new DecisionTreeRegressorWrapper(pipeline, formula, features)
}
}
}
@@ -60,6 +60,10 @@ private[r] object RWrappers extends MLReader[Object] {
RandomForestRegressorWrapper.load(path)
case "org.apache.spark.ml.r.RandomForestClassifierWrapper" =>
RandomForestClassifierWrapper.load(path)
case "org.apache.spark.ml.r.DecisionTreeRegressorWrapper" =>
DecisionTreeRegressorWrapper.load(path)
case "org.apache.spark.ml.r.DecisionTreeClassifierWrapper" =>
DecisionTreeClassifierWrapper.load(path)
case "org.apache.spark.ml.r.GBTRegressorWrapper" =>
GBTRegressorWrapper.load(path)
case "org.apache.spark.ml.r.GBTClassifierWrapper" =>
......