Skip to content
Snippets Groups Projects
Commit 07d72fe6 authored by Manish Amde's avatar Manish Amde Committed by Patrick Wendell
Browse files

Decision Tree documentation for MLlib programming guide

Added documentation for user to use the decision tree algorithms for classification and regression in Spark 1.0 release.

Apart from a general review, I need specific input on the following:
* I had to move a lot of the existing documentation under the *linear methods* umbrella to accommodate decision trees. I wonder if there is a better way to organize the programming guide given we are so close to the release.
* I have not looked closely at pyspark but I am wondering new mllib algorithms are automatically plugged in or do we need to some extra work to call mllib functions from pyspark. I will add to the pyspark examples based upon the advice I get.

cc: @mengxr, @hirakendu, @etrain, @atalwalkar

Author: Manish Amde <manish9ue@gmail.com>

Closes #402 from manishamde/tree_doc and squashes the following commits:

022485a [Manish Amde] more documentation
865826e [Manish Amde] minor: grammar
dbb0e5e [Manish Amde] minor improvements to text
b9ef6c4 [Manish Amde] basic decision tree code examples
6e297d7 [Manish Amde] added subsections
f427e84 [Manish Amde] renaming sections
9c0c4be [Manish Amde] split candidate
6925275 [Manish Amde] impurity and information gain
94fd2f9 [Manish Amde] more reorg
b93125c [Manish Amde] more subsection reorg
3ecb2ad [Manish Amde] minor text addition
1537dd3 [Manish Amde] added placeholders and some doc
d06511d [Manish Amde] basic skeleton
parent 6843d637
No related branches found
No related tags found
No related merge requests found
...@@ -40,8 +40,9 @@ Supervised Learning involves executing a learning *Algorithm* on a set of *label ...@@ -40,8 +40,9 @@ Supervised Learning involves executing a learning *Algorithm* on a set of *label
examples. The algorithm returns a trained *Model* (such as for example a linear function) that examples. The algorithm returns a trained *Model* (such as for example a linear function) that
can predict the label for new data examples for which the label is unknown. can predict the label for new data examples for which the label is unknown.
## Discriminative Training using Linear Methods
## Mathematical Formulation ### Mathematical Formulation
Many standard *machine learning* methods can be formulated as a convex optimization problem, i.e. Many standard *machine learning* methods can be formulated as a convex optimization problem, i.e.
the task of finding a minimizer of a convex function `$f$` that depends on a variable vector the task of finding a minimizer of a convex function `$f$` that depends on a variable vector
`$\wv$` (called `weights` in the code), which has `$d$` entries. `$\wv$` (called `weights` in the code), which has `$d$` entries.
...@@ -71,7 +72,7 @@ The fixed regularization parameter `$\lambda\ge0$` (`regParam` in the code) defi ...@@ -71,7 +72,7 @@ The fixed regularization parameter `$\lambda\ge0$` (`regParam` in the code) defi
between the two goals of small loss and small model complexity. between the two goals of small loss and small model complexity.
## Binary Classification ### Binary Classification
**Input:** Datapoints `$\x_i\in\R^{d}$`, labels `$y_i\in\{+1,-1\}$`, for `$1\le i\le n$`. **Input:** Datapoints `$\x_i\in\R^{d}$`, labels `$y_i\in\{+1,-1\}$`, for `$1\le i\le n$`.
...@@ -83,7 +84,7 @@ In other words, the input distributed dataset ...@@ -83,7 +84,7 @@ In other words, the input distributed dataset
([RDD](scala-programming-guide.html#resilient-distributed-datasets-rdds)) must be the set of ([RDD](scala-programming-guide.html#resilient-distributed-datasets-rdds)) must be the set of
vectors `$\x_i\in\R^d$`. vectors `$\x_i\in\R^d$`.
### Support Vector Machine #### Support Vector Machine
The linear [Support Vector Machine (SVM)](http://en.wikipedia.org/wiki/Support_vector_machine) The linear [Support Vector Machine (SVM)](http://en.wikipedia.org/wiki/Support_vector_machine)
has become a standard choice for classification tasks. has become a standard choice for classification tasks.
Here the loss function in formulation `$\eqref{eq:regPrimal}$` is given by the hinge-loss Here the loss function in formulation `$\eqref{eq:regPrimal}$` is given by the hinge-loss
...@@ -95,7 +96,7 @@ By default, SVMs are trained with an L2 regularization, which gives rise to the ...@@ -95,7 +96,7 @@ By default, SVMs are trained with an L2 regularization, which gives rise to the
interpretation if these classifiers. We also support alternative L1 regularization. In this case, interpretation if these classifiers. We also support alternative L1 regularization. In this case,
the primal optimization problem becomes an [LP](http://en.wikipedia.org/wiki/Linear_programming). the primal optimization problem becomes an [LP](http://en.wikipedia.org/wiki/Linear_programming).
### Logistic Regression #### Logistic Regression
Despite its name, [Logistic Regression](http://en.wikipedia.org/wiki/Logistic_regression) is a Despite its name, [Logistic Regression](http://en.wikipedia.org/wiki/Logistic_regression) is a
binary classification method, again when the labels are given by binary values binary classification method, again when the labels are given by binary values
`$y_i\in\{+1,-1\}$`. The logistic loss function in formulation `$\eqref{eq:regPrimal}$` is `$y_i\in\{+1,-1\}$`. The logistic loss function in formulation `$\eqref{eq:regPrimal}$` is
...@@ -105,7 +106,7 @@ L(\wv;\x_i,y_i) := \log(1+\exp( -y_i \wv^T \x_i)) \ . ...@@ -105,7 +106,7 @@ L(\wv;\x_i,y_i) := \log(1+\exp( -y_i \wv^T \x_i)) \ .
\]` \]`
## Linear Regression (Least Squares, Lasso and Ridge Regression) ### Linear Regression (Least Squares, Lasso and Ridge Regression)
**Input:** Data matrix `$A\in\R^{n\times d}$`, right hand side vector `$\y\in\R^n$`. **Input:** Data matrix `$A\in\R^{n\times d}$`, right hand side vector `$\y\in\R^n$`.
...@@ -121,17 +122,17 @@ linear combination of our observed data `$A\in\R^{n\times d}$`, which is given a ...@@ -121,17 +122,17 @@ linear combination of our observed data `$A\in\R^{n\times d}$`, which is given a
It comes in 3 flavors: It comes in 3 flavors:
### Least Squares #### Least Squares
Plain old [least squares](http://en.wikipedia.org/wiki/Least_squares) linear regression is the Plain old [least squares](http://en.wikipedia.org/wiki/Least_squares) linear regression is the
problem of minimizing problem of minimizing
`\[ f_{\text{LS}}(\wv) := \frac1n \|A\wv-\y\|_2^2 \ . \]` `\[ f_{\text{LS}}(\wv) := \frac1n \|A\wv-\y\|_2^2 \ . \]`
### Lasso #### Lasso
The popular [Lasso](http://en.wikipedia.org/wiki/Lasso_(statistics)#Lasso_method) (alternatively The popular [Lasso](http://en.wikipedia.org/wiki/Lasso_(statistics)#Lasso_method) (alternatively
also known as `$L_1$`-regularized least squares regression) is given by also known as `$L_1$`-regularized least squares regression) is given by
`\[ f_{\text{Lasso}}(\wv) := \frac1n \|A\wv-\y\|_2^2 + \lambda \|\wv\|_1 \ . \]` `\[ f_{\text{Lasso}}(\wv) := \frac1n \|A\wv-\y\|_2^2 + \lambda \|\wv\|_1 \ . \]`
### Ridge Regression #### Ridge Regression
[Ridge regression](http://en.wikipedia.org/wiki/Ridge_regression) uses the same loss function but [Ridge regression](http://en.wikipedia.org/wiki/Ridge_regression) uses the same loss function but
with a L2 regularizer term: with a L2 regularizer term:
`\[ f_{\text{Ridge}}(\wv) := \frac1n \|A\wv-\y\|_2^2 + \frac{\lambda}{2}\|\wv\|^2 \ . \]` `\[ f_{\text{Ridge}}(\wv) := \frac1n \|A\wv-\y\|_2^2 + \frac{\lambda}{2}\|\wv\|^2 \ . \]`
...@@ -150,7 +151,7 @@ In our generic problem formulation `$\eqref{eq:regPrimal}$`, this means the loss ...@@ -150,7 +151,7 @@ In our generic problem formulation `$\eqref{eq:regPrimal}$`, this means the loss
the data matrix `$A$`. the data matrix `$A$`.
## Using Different Regularizers ### Using Different Regularizers
As we have mentioned above, the purpose of *regularizer* in `$\eqref{eq:regPrimal}$` is to As we have mentioned above, the purpose of *regularizer* in `$\eqref{eq:regPrimal}$` is to
encourage simple models, by punishing the complexity of the model `$\wv$`, in order to e.g. avoid encourage simple models, by punishing the complexity of the model `$\wv$`, in order to e.g. avoid
...@@ -178,7 +179,7 @@ the 3 mentioned here can be conveniently optimized with gradient descent type me ...@@ -178,7 +179,7 @@ the 3 mentioned here can be conveniently optimized with gradient descent type me
SGD) which is implemented in `MLlib` currently, and explained in the next section. SGD) which is implemented in `MLlib` currently, and explained in the next section.
# Optimization Methods Working on the Primal Formulation ### Optimization Methods Working on the Primal Formulation
**Stochastic subGradient Descent (SGD).** **Stochastic subGradient Descent (SGD).**
For optimization objectives `$f$` written as a sum, *stochastic subgradient descent (SGD)* can be For optimization objectives `$f$` written as a sum, *stochastic subgradient descent (SGD)* can be
...@@ -239,11 +240,72 @@ Here `$\mathop{sign}(\wv)$` is the vector consisting of the signs (`$\pm1$`) of ...@@ -239,11 +240,72 @@ Here `$\mathop{sign}(\wv)$` is the vector consisting of the signs (`$\pm1$`) of
of `$\wv$`. of `$\wv$`.
Also, note that `$A_{i:} \in \R^d$` is a row-vector, but the gradient is a column vector. Also, note that `$A_{i:} \in \R^d$` is a row-vector, but the gradient is a column vector.
## Decision Tree Classification and Regression
Decision trees and their ensembles are popular methods for the machine learning tasks of classification and regression. Decision trees are widely used since they are easy to interpret, handle categorical variables, extend to the multi-class classification setting, do not require feature scaling and are able to capture non-linearities and feature interactions. Tree ensemble algorithms such as decision forest and boosting are among the top performers for classification and regression tasks.
### Basic Algorithm
The decision tree is a greedy algorithm that performs a recursive binary partitioning of the feature space by choosing a single element from the *best split set* where each element of the set maximimizes the information gain at a tree node. In other words, the split chosen at each tree node is chosen from the set `$\underset{s}{\operatorname{argmax}} IG(D,s)$` where `$IG(D,s)$` is the information gain when a split `$s$` is applied to a dataset `$D$`.
#### Node Impurity and Information Gain
The *node impurity* is a measure of the homogeneity of the labels at the node. The current implementation provides two impurity measures for classification (Gini index and entropy) and one impurity measure for regression (variance).
<table class="table">
<thead>
<tr><th>Impurity</th><th>Task</th><th>Formula</th><th>Description</th></tr>
</thead>
<tbody>
<tr>
<td>Gini index</td><td>Classification</td><td>$\sum_{i=1}^{M} f_i(1-f_i)$</td><td>$f_i$ is the frequency of label $i$ at a node and $M$ is the number of unique labels.</td>
</tr>
<tr>
<td>Entropy</td><td>Classification</td><td>$\sum_{i=1}^{M} -f_ilog(f_i)$</td><td>$f_i$ is the frequency of label $i$ at a node and $M$ is the number of unique labels.</td>
</tr>
<tr>
<td>Variance</td><td>Classification</td><td>$\frac{1}{n} \sum_{i=1}^{N} (x_i - \mu)^2$</td><td>$y_i$ is label for an instance, $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^n x_i$.</td>
</tr>
</tbody>
</table>
The *information gain* is the difference in the parent node impurity and the weighted sum of the two child node impurities. Assuming that a split $s$ partitions the dataset `$D$` of size `$N$` into two datasets `$D_{left}$` and `$D_{right}$` of sizes `$N_{left}$` and `$N_{right}$`, respectively:
`$IG(D,s) = Impurity(D) - \frac{N_{left}}{N} Impurity(D_{left}) - \frac{N_{right}}{N} Impurity(D_{right})$`
#### Split Candidates
**Continuous Features**
For small datasets in single machine implementations, the split candidates for each continuous feature are typically the unique values for the feature. Some implementations sort the feature values and then use the ordered unique values as split candidates for faster tree calculations.
Finding ordered unique feature values is computationally intensive for large distributed datasets. One can get an approximate set of split candidates by performing a quantile calculation over a sampled fraction of the data. The ordered splits create "bins" and the maximum number of such bins can be specified using the `maxBins` parameters.
Note that the number of bins cannot be greater than the number of instances `$N$` (a rare scenario since the default `maxBins` value is 100). The tree algorithm automatically reduces the number of bins if the condition is not satisfied.
**Categorical Features**
For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the categorical feature values by the proportion of labels falling in one of the two classes (see Section 9.2.4 in [Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for details). For example, for a binary classification problem with one categorical feature with three categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical features are orded as A followed by C followed B or A, B, C. The two split candidates are A \| C, B and A , B \| C where \| denotes the split.
#### Stopping Rule
The recursive tree construction is stopped at a node when one of the two conditions is met:
1. The node depth is equal to the `maxDepth` training paramemter
2. No split candidate leads to an information gain at the node.
### Practical Limitations
The tree implementation stores an Array[Double] of size *O(#features \* #splits \* 2^maxDepth)* in memory for aggregating histograms over partitions. The current implementation might not scale to very deep trees since the memory requirement grows exponentially with tree depth.
Please drop us a line if you encounter any issues. We are planning to solve this problem in the near future and real-world examples will be great.
## Implementation in MLlib ## Implementation in MLlib
For both classification and regression, `MLlib` implements a simple distributed version of #### Linear Methods
For both classification and regression algorithms with convex loss functions, `MLlib` implements a simple distributed version of
stochastic subgradient descent (SGD), building on the underlying gradient descent primitive (as stochastic subgradient descent (SGD), building on the underlying gradient descent primitive (as
described in the described in the
<a href="mllib-optimization.html">optimization section</a>). <a href="mllib-optimization.html">optimization section</a>).
...@@ -269,15 +331,21 @@ gradient descent primitive in MLlib, see the ...@@ -269,15 +331,21 @@ gradient descent primitive in MLlib, see the
* [GradientDescent](api/mllib/index.html#org.apache.spark.mllib.optimization.GradientDescent) * [GradientDescent](api/mllib/index.html#org.apache.spark.mllib.optimization.GradientDescent)
#### Tree-based Methods
The decision tree algorithm supports binary classification and regression:
* [DecisionTee](api/mllib/index.html#org.apache.spark.mllib.tree.DecisionTree)
# Usage in Scala # Usage in Scala
Following code snippets can be executed in `spark-shell`. Following code snippets can be executed in `spark-shell`.
## Binary Classification ## Linear Methods
#### Binary Classification
The following code snippet illustrates how to load a sample dataset, execute a The following code snippet illustrates how to load a sample dataset, execute a
training algorithm on this training data using a static method in the algorithm training algorithm on this training data using a static method in the algorithm
...@@ -328,7 +396,7 @@ svmAlg.optimizer.setNumIterations(200) ...@@ -328,7 +396,7 @@ svmAlg.optimizer.setNumIterations(200)
val modelL1 = svmAlg.run(parsedData) val modelL1 = svmAlg.run(parsedData)
{% endhighlight %} {% endhighlight %}
## Linear Regression #### Linear Regression
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint.
The example then uses LinearRegressionWithSGD to build a simple linear model to predict label The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
...@@ -363,6 +431,73 @@ println("training Mean Squared Error = " + MSE) ...@@ -363,6 +431,73 @@ println("training Mean Squared Error = " + MSE)
Similarly you can use RidgeRegressionWithSGD and LassoWithSGD and compare training Similarly you can use RidgeRegressionWithSGD and LassoWithSGD and compare training
[Mean Squared Errors](http://en.wikipedia.org/wiki/Mean_squared_error). [Mean Squared Errors](http://en.wikipedia.org/wiki/Mean_squared_error).
## Decision Tree
#### Classification
The example below demonstrates how to load a CSV file, parse it as an RDD of LabeledPoint and then perform classification using a decision tree using Gini index as an impurity measure and a maximum tree depth of 5. The training error is calculated to measure the algorithm accuracy.
{% highlight scala %}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.impurity.Gini
// Load and parse the data file
val data = sc.textFile("mllib/data/sample_tree_data.csv")
val parsedData = data.map { line =>
val parts = line.split(',').map(_.toDouble)
LabeledPoint(parts(0), Vectors.dense(parts.tail))
}
// Run training algorithm to build the model
val maxDepth = 5
val model = DecisionTree.train(parsedData, Classification, Gini, maxDepth)
// Evaluate model on training examples and compute training error
val labelAndPreds = parsedData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / parsedData.count
println("Training Error = " + trainErr)
{% endhighlight %}
#### Regression
The example below demonstrates how to load a CSV file, parse it as an RDD of LabeledPoint and then perform regression using a decision tree using variance as an impurity measure and a maximum tree depth of 5. The Mean Squared Error is computed at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).
{% highlight scala %}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.impurity.Variance
// Load and parse the data file
val data = sc.textFile("mllib/data/sample_tree_data.csv")
val parsedData = data.map { line =>
val parts = line.split(',').map(_.toDouble)
LabeledPoint(parts(0), Vectors.dense(parts.tail))
}
// Run training algorithm to build the model
val maxDepth = 5
val model = DecisionTree.train(parsedData, Regression, Variance, maxDepth)
// Evaluate model on training examples and compute training error
val valuesAndPreds = parsedData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2)}.reduce(_ + _)/valuesAndPreds.count
println("training Mean Squared Error = " + MSE)
{% endhighlight %}
# Usage in Java # Usage in Java
...@@ -375,7 +510,9 @@ calling `.rdd()` on your `JavaRDD` object. ...@@ -375,7 +510,9 @@ calling `.rdd()` on your `JavaRDD` object.
Following examples can be tested in the PySpark shell. Following examples can be tested in the PySpark shell.
## Binary Classification ## Linear Methods
### Binary Classification
The following example shows how to load a sample dataset, build Logistic Regression model, The following example shows how to load a sample dataset, build Logistic Regression model,
and make predictions with the resulting model to compute the training error. and make predictions with the resulting model to compute the training error.
...@@ -397,7 +534,7 @@ trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedDa ...@@ -397,7 +534,7 @@ trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedDa
print("Training Error = " + str(trainErr)) print("Training Error = " + str(trainErr))
{% endhighlight %} {% endhighlight %}
## Linear Regression ### Linear Regression
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint.
The example then uses LinearRegressionWithSGD to build a simple linear model to predict label The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
values. We compute the Mean Squared Error at the end to evaluate values. We compute the Mean Squared Error at the end to evaluate
...@@ -419,4 +556,4 @@ valuesAndPreds = parsedData.map(lambda point: (point.item(0), ...@@ -419,4 +556,4 @@ valuesAndPreds = parsedData.map(lambda point: (point.item(0),
model.predict(point.take(range(1, point.size))))) model.predict(point.take(range(1, point.size)))))
MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y)/valuesAndPreds.count() MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y)/valuesAndPreds.count()
print("Mean Squared Error = " + str(MSE)) print("Mean Squared Error = " + str(MSE))
{% endhighlight %} {% endhighlight %}
\ No newline at end of file
...@@ -21,6 +21,7 @@ The following links provide a detailed explanation of the methods and usage exam ...@@ -21,6 +21,7 @@ The following links provide a detailed explanation of the methods and usage exam
* Least Squares * Least Squares
* Lasso * Lasso
* Ridge Regression * Ridge Regression
* Decision Tree (for classification and regression)
* <a href="mllib-clustering.html">Clustering</a> * <a href="mllib-clustering.html">Clustering</a>
* k-Means * k-Means
* <a href="mllib-collaborative-filtering.html">Collaborative Filtering</a> * <a href="mllib-collaborative-filtering.html">Collaborative Filtering</a>
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment