Commit 2a0fe348 authored by MechCoder, committed by Joseph K. Bradley

[SPARK-5436] [MLlib] Validate GradientBoostedTrees using runWithValidation

One can stop early if the decrease in error rate is less than a certain tolerance, or if the validation error increases, i.e., if the model overfits the training data.

This introduces a new method, runWithValidation, which takes a pair of RDDs: one for the training data and the other for the validation data.
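
As an illustration, the stopping rule amounts to the following check after each boosting iteration (a minimal sketch; the names bestError and currentError are hypothetical, while validationTol is the new BoostingStrategy parameter):

    // Sketch of the early-stopping rule applied after each iteration.
    // bestError: best validation error seen so far; currentError: error of
    // the current partial model on the validation set.
    def shouldStopEarly(bestError: Double, currentError: Double,
        validationTol: Double): Boolean = {
      // Fires when the improvement falls below the tolerance, including the
      // case where the validation error has started to increase.
      bestError - currentError < validationTol
    }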

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #4677 from MechCoder/spark-5436 and squashes the following commits:

1bb21d4 [MechCoder] Combine regression and classification tests into a single one
e4d799b [MechCoder] Addresses indentation and doc comments
b48a70f [MechCoder] COSMIT
b928a19 [MechCoder] Move validation while training section under usage tips
fad9b6e [MechCoder] Made the following changes 1. Add section to documentation 2. Return corresponding to bestValidationError 3. Allow negative tolerance.
55e5c3b [MechCoder] One liner for prevValidateError
3e74372 [MechCoder] TST: Add test for classification
77549a9 [MechCoder] [SPARK-5436] Validate GradientBoostedTrees using runWithValidation
parent da505e59
......@@ -427,6 +427,17 @@ We omit some decision tree parameters since those are covered in the [decision t
* **`algo`**: The algorithm or task (classification vs. regression) is set using the tree [Strategy] parameter.
#### Validation while training
Gradient boosting can overfit when trained with more trees. In order to prevent overfitting, it is useful to validate while
training. The method `runWithValidation` has been provided to make use of this option. It takes a pair of RDDs as arguments, the
first one being the training dataset and the second being the validation dataset.
The training is stopped when the improvement in the validation error is not more than a certain tolerance
(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error
decreases initially and later increases. There might be cases in which the validation error does not change monotonically,
and the user is advised to set a large enough negative tolerance and examine the validation curve to tune the number of
iterations.
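
As a rough usage sketch (assuming a dataset `data: RDD[LabeledPoint]` and a configured `boostingStrategy`; the 80/20 split is arbitrary), training with validation could look like:

    // Hold out a validation set, then train with early stopping.
    val Array(trainingData, validationData) = data.randomSplit(Array(0.8, 0.2))
    val model = new GradientBoostedTrees(boostingStrategy)
      .runWithValidation(trainingData, validationData)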
### Examples
......
......@@ -60,11 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, boostingStrategy)
GradientBoostedTrees.boost(remappedInput,
remappedInput, boostingStrategy, validate=false)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
......@@ -76,8 +77,46 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
run(input.rdd)
}
}
/**
* Method to train a gradient boosting model while validating against a held-out dataset.
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @param validationInput Validation dataset:
RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
Should be different from and follow the same distribution as input.
e.g., these two datasets could be created from an original dataset
by using [[org.apache.spark.rdd.RDD.randomSplit()]]
* @return a gradient boosted trees model that can be used for prediction
*/
def runWithValidation(
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case Regression => GradientBoostedTrees.boost(
input, validationInput, boostingStrategy, validate=true)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(
x => new LabeledPoint((x.label * 2) - 1, x.features))
val remappedValidationInput = validationInput.map(
x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
validate=true)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
}
/**
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]].
*/
def runWithValidation(
input: JavaRDD[LabeledPoint],
validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
runWithValidation(input.rdd, validationInput.rdd)
}
}
object GradientBoostedTrees extends Logging {
......@@ -108,12 +147,16 @@ object GradientBoostedTrees extends Logging {
/**
* Internal method for performing regression using trees as base learners.
* @param input training dataset
* @param validationInput validation dataset, ignored if validate is set to false.
* @param boostingStrategy boosting parameters
* @param validate whether or not to use the validation dataset.
* @return a gradient boosted trees model that can be used for prediction
*/
private def boost(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
validationInput: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy,
validate: Boolean): GradientBoostedTreesModel = {
val timer = new TimeTracker()
timer.start("total")
......@@ -129,6 +172,7 @@ object GradientBoostedTrees extends Logging {
val learningRate = boostingStrategy.learningRate
// Prepare strategy for individual trees, which use regression with variance impurity.
val treeStrategy = boostingStrategy.treeStrategy.copy
val validationTol = boostingStrategy.validationTol
treeStrategy.algo = Regression
treeStrategy.impurity = Variance
treeStrategy.assertValid()
......@@ -152,13 +196,16 @@ object GradientBoostedTrees extends Logging {
baseLearnerWeights(0) = 1.0
val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
logDebug("error of gbt = " + loss.computeError(startingModel, input))
// Note: A model of type regression is used since we require raw prediction
timer.stop("building tree 0")
var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0
var bestM = 1
// Pseudo-residual for the second iteration.
data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
point.features))
var m = 1
while (m < numIterations) {
timer.start(s"building tree $m")
......@@ -177,6 +224,23 @@ object GradientBoostedTrees extends Logging {
val partialModel = new GradientBoostedTreesModel(
Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
logDebug("error of gbt = " + loss.computeError(partialModel, input))
if (validate) {
// Stop training early if:
// 1. the reduction in error is less than validationTol, or
// 2. the error increases, i.e., the model has begun to overfit.
// We return the model corresponding to the best validation error.
val currentValidateError = loss.computeError(partialModel, validationInput)
if (bestValidateError - currentValidateError < validationTol) {
return new GradientBoostedTreesModel(
boostingStrategy.treeStrategy.algo,
baseLearners.slice(0, bestM),
baseLearnerWeights.slice(0, bestM))
} else if (currentValidateError < bestValidateError) {
bestValidateError = currentValidateError
bestM = m + 1
}
}
// Update data with pseudo-residuals
data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
point.features))
......@@ -191,4 +255,5 @@ object GradientBoostedTrees extends Logging {
new GradientBoostedTreesModel(
boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
}
}
......@@ -34,6 +34,9 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* weak hypotheses used in the final model.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
*                     learning rate should be in the interval (0, 1]
* @param validationTol Used when runWithValidation is called. If the reduction in error rate
*                      on the validation input between two iterations is less than
*                      validationTol, then training stops. Ignored when [[run]] is used.
*/
@Experimental
case class BoostingStrategy(
......@@ -42,7 +45,8 @@ case class BoostingStrategy(
@BeanProperty var loss: Loss,
// Optional boosting parameters
@BeanProperty var numIterations: Int = 100,
@BeanProperty var learningRate: Double = 0.1) extends Serializable {
@BeanProperty var learningRate: Double = 0.1,
@BeanProperty var validationTol: Double = 1e-5) extends Serializable {
/**
* Check validity of parameters.
......
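
As a brief configuration sketch (assuming the existing BoostingStrategy.defaultParams factory; the setter name follows from the @BeanProperty annotation above, and the 1e-4 value is arbitrary):

    // Configure early stopping through the new validationTol parameter.
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.setValidationTol(1e-4) // stop once improvement drops below 1e-4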
......@@ -158,6 +158,40 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
}
}
}
test("runWithValidation stops early and performs better on a validation dataset") {
// Set numIterations large enough so that it stops early.
val numIterations = 20
val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
val algos = Array(Regression, Regression, Classification)
val losses = Array(SquaredError, AbsoluteError, LogLoss)
(algos zip losses) map {
case (algo, loss) => {
val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
categoricalFeaturesInfo = Map.empty)
val boostingStrategy =
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
val gbtValidate = new GradientBoostedTrees(boostingStrategy)
.runWithValidation(trainRdd, validateRdd)
assert(gbtValidate.numTrees !== numIterations)
// Test that it performs better on the validation dataset.
val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
val (errorWithoutValidation, errorWithValidation) = {
if (algo == Classification) {
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
(loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
} else {
(loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
}
}
assert(errorWithValidation <= errorWithoutValidation)
}
}
}
}
private object GradientBoostedTreesSuite {
......@@ -166,4 +200,6 @@ private object GradientBoostedTreesSuite {
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
val trainData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120)
val validateData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80)
}