Commit a97c4970 authored by Zheng RuiFeng, committed by Felix Cheung

[SPARK-20849][DOC][SPARKR] Document R DecisionTree

## What changes were proposed in this pull request?
1. Add an example for SparkR `decisionTree`.
2. Document it in the user guide.

## How was this patch tested?
local submit

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #18067 from zhengruifeng/dt_example.
parent 8ce0d8ff
......@@ -503,6 +503,8 @@ SparkR supports the following machine learning models and algorithms.
#### Tree - Classification and Regression
* Decision Tree
* Gradient-Boosted Trees (GBT)
* Random Forest
......@@ -776,16 +778,32 @@ newDF <- createDataFrame(data.frame(x = c(1.5, 3.2)))
head(predict(isoregModel, newDF))
```
#### Decision Tree
`spark.decisionTree` fits a [decision tree](https://en.wikipedia.org/wiki/Decision_tree_learning) classification or regression model on a `SparkDataFrame`.
Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models.
We use the `Titanic` dataset to train a decision tree and make predictions:
```{r}
t <- as.data.frame(Titanic)
df <- createDataFrame(t)
dtModel <- spark.decisionTree(df, Survived ~ ., type = "classification", maxDepth = 2)
summary(dtModel)
predictions <- predict(dtModel, df)
```
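The vignette text above mentions `write.ml`/`read.ml`, but the example does not exercise them. The following is a minimal sketch of a save/load round trip and is not part of this commit. It assumes a local Spark installation with SparkR available; the app name and the temporary path are chosen purely for illustration.

```r
library(SparkR)
sparkR.session(appName = "decisionTree-save-load-sketch")  # hypothetical app name

# Same data preparation as the vignette example
t <- as.data.frame(Titanic)
df <- createDataFrame(t)
dtModel <- spark.decisionTree(df, Survived ~ ., type = "classification", maxDepth = 2)

# Save the fitted model to a temporary location (illustrative path)
modelPath <- tempfile(pattern = "spark-decisionTree", fileext = ".tmp")
write.ml(dtModel, modelPath)

# Load the model back and predict with the restored copy
savedModel <- read.ml(modelPath)
preds <- head(predict(savedModel, df))
preds

# Clean up the saved model and stop the session
unlink(modelPath, recursive = TRUE)
sparkR.session.stop()
```

The restored model should behave like the original for `summary` and `predict`; `write.ml` refuses to overwrite an existing path unless `overwrite = TRUE` is passed.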
#### Gradient-Boosted Trees
`spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`.
Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models.
We use the `longley` dataset to train a gradient-boosted tree and make predictions:
We use the `Titanic` dataset to train a gradient-boosted tree and make predictions:
```{r, warning=FALSE}
df <- createDataFrame(longley)
gbtModel <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 2, maxIter = 2)
```{r}
t <- as.data.frame(Titanic)
df <- createDataFrame(t)
gbtModel <- spark.gbt(df, Survived ~ ., type = "classification", maxDepth = 2, maxIter = 2)
summary(gbtModel)
predictions <- predict(gbtModel, df)
```
......@@ -795,11 +813,12 @@ predictions <- predict(gbtModel, df)
`spark.randomForest` fits a [random forest](https://en.wikipedia.org/wiki/Random_forest) classification or regression model on a `SparkDataFrame`.
Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models.
In the following example, we use the `longley` dataset to train a random forest and make predictions:
In the following example, we use the `Titanic` dataset to train a random forest and make predictions:
```{r, warning=FALSE}
df <- createDataFrame(longley)
rfModel <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 2, numTrees = 2)
```{r}
t <- as.data.frame(Titanic)
df <- createDataFrame(t)
rfModel <- spark.randomForest(df, Survived ~ ., type = "classification", maxDepth = 2, numTrees = 2)
summary(rfModel)
predictions <- predict(rfModel, df)
```
......@@ -965,17 +984,18 @@ Given a `SparkDataFrame`, the test compares continuous data in a given column `t
specified by parameter `nullHypothesis`.
Users can call `summary` to get a summary of the test results.
In the following example, we test whether the `longley` dataset's `Armed_Forces` column
In the following example, we test whether the `Titanic` dataset's `Freq` column
follows a normal distribution. We set the parameters of the normal distribution using
the mean and standard deviation of the sample.
```{r, warning=FALSE}
df <- createDataFrame(longley)
afStats <- head(select(df, mean(df$Armed_Forces), sd(df$Armed_Forces)))
afMean <- afStats[1]
afStd <- afStats[2]
```{r}
t <- as.data.frame(Titanic)
df <- createDataFrame(t)
freqStats <- head(select(df, mean(df$Freq), sd(df$Freq)))
freqMean <- freqStats[1]
freqStd <- freqStats[2]
test <- spark.kstest(df, "Armed_Forces", "norm", c(afMean, afStd))
test <- spark.kstest(df, "Freq", "norm", c(freqMean, freqStd))
testSummary <- summary(test)
testSummary
```
......
......@@ -708,6 +708,13 @@ More details on parameters can be found in the [Python API documentation](api/py
{% include_example python/ml/decision_tree_regression_example.py %}
</div>
<div data-lang="r" markdown="1">
Refer to the [R API docs](api/R/spark.decisionTree.html) for more details.
{% include_example regression r/ml/decisionTree.R %}
</div>
</div>
......
......@@ -492,6 +492,7 @@ SparkR supports the following machine learning algorithms currently:
#### Tree
* [`spark.decisionTree`](api/R/spark.decisionTree.html): `Decision Tree for` [`Regression`](ml-classification-regression.html#decision-tree-regression) `and` [`Classification`](ml-classification-regression.html#decision-tree-classifier)
* [`spark.gbt`](api/R/spark.gbt.html): `Gradient Boosted Trees for` [`Regression`](ml-classification-regression.html#gradient-boosted-tree-regression) `and` [`Classification`](ml-classification-regression.html#gradient-boosted-tree-classifier)
* [`spark.randomForest`](api/R/spark.randomForest.html): `Random Forest for` [`Regression`](ml-classification-regression.html#random-forest-regression) `and` [`Classification`](ml-classification-regression.html#random-forest-classifier)
......
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# To run this example use
# ./bin/spark-submit examples/src/main/r/ml/decisionTree.R
# Load SparkR library into your R session
library(SparkR)
# Initialize SparkSession
sparkR.session(appName = "SparkR-ML-decisionTree-example")
# DecisionTree classification model
# $example on:classification$
# Load training data
df <- read.df("data/mllib/sample_libsvm_data.txt", source = "libsvm")
training <- df
test <- df
# Fit a DecisionTree classification model with spark.decisionTree
model <- spark.decisionTree(training, label ~ features, "classification")
# Model summary
summary(model)
# Prediction
predictions <- predict(model, test)
head(predictions)
# $example off:classification$
# DecisionTree regression model
# $example on:regression$
# Load training data
df <- read.df("data/mllib/sample_linear_regression_data.txt", source = "libsvm")
training <- df
test <- df
# Fit a DecisionTree regression model with spark.decisionTree
model <- spark.decisionTree(training, label ~ features, "regression")
# Model summary
summary(model)
# Prediction
predictions <- predict(model, test)
head(predictions)
# $example off:regression$
sparkR.session.stop()
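
The example file above trains and predicts on the same DataFrame (`training <- df`, `test <- df`). As a hedged variation that is not part of this commit, the data could instead be split with `randomSplit` so that predictions are made on held-out rows. The split weights, seed, and app name below are illustrative assumptions, and the relative data path assumes the script is run from the Spark home directory, as in the original example.

```r
library(SparkR)
sparkR.session(appName = "decisionTree-split-sketch")  # hypothetical app name

# Load the same sample data used by the classification example
df <- read.df("data/mllib/sample_libsvm_data.txt", source = "libsvm")

# Hold out roughly 30% of the rows; weights and seed are illustrative
splits <- randomSplit(df, c(0.7, 0.3), seed = 42)
training <- splits[[1]]
test <- splits[[2]]

# Fit on the training split and predict on the held-out split
model <- spark.decisionTree(training, label ~ features, "classification")
predictions <- predict(model, test)
preview <- head(select(predictions, "label", "prediction"))
preview

sparkR.session.stop()
```

Because `randomSplit` assigns rows probabilistically, the split sizes only approximate the requested weights.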