Commit 2a55cb41 authored by MechCoder, committed by Joseph K. Bradley

[SPARK-5972] [MLlib] Cache residuals and gradient in GBT during training and validation

The previous PR (https://github.com/apache/spark/pull/4906) made it possible to extract the learning curve, i.e. the error at each iteration. This PR continues that work by refactoring some code and applying the same caching logic during training and validation.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #5330 from MechCoder/spark-5972 and squashes the following commits:

0b5d659 [MechCoder] minor
32d409d [MechCoder] EvaluateeachIteration and training cache should follow different paths
d542bb0 [MechCoder] Remove unused imports and docs
58f4932 [MechCoder] Remove unpersist
70d3b4c [MechCoder] Broadcast for each tree
5869533 [MechCoder] Access broadcasted values locally and other minor changes
923dbf6 [MechCoder] [SPARK-5972] Cache residuals and gradient in GBT during training and validation
parent 3a205bbd
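
The heart of the change is easiest to see in isolation: rather than re-scoring the whole partial ensemble after every boosting iteration, which costs O(m) tree predictions per point at iteration m, the patch keeps an RDD of (prediction, error) pairs and folds in only the newest tree. A minimal sketch of that idea with hypothetical names (foldInTree is not part of the patch; the real code is in the hunks below), assuming the pointwise Loss.computeError introduced alongside this patch is accessible:

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.loss.Loss
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.rdd.RDD

// Sketch only: fold one new tree into cached (prediction, error) pairs
// instead of re-evaluating every tree of the ensemble.
def foldInTree(
    data: RDD[LabeledPoint],
    predError: RDD[(Double, Double)],
    tree: DecisionTreeModel,
    treeWeight: Double,
    loss: Loss): RDD[(Double, Double)] = {
  data.zip(predError).map { case (lp, (pred, _)) =>
    // Add only the new tree's weighted contribution to the cached prediction.
    val newPred = pred + treeWeight * tree.predict(lp.features)
    (newPred, loss.computeError(newPred, lp.label))
  }
}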
@@ -157,7 +157,6 @@ object GradientBoostedTrees extends Logging {
       validationInput: RDD[LabeledPoint],
       boostingStrategy: BoostingStrategy,
       validate: Boolean): GradientBoostedTreesModel = {
-
     val timer = new TimeTracker()
     timer.start("total")
     timer.start("init")
@@ -192,20 +191,29 @@ object GradientBoostedTrees extends Logging {
     // Initialize tree
     timer.start("building tree 0")
     val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+    val firstTreeWeight = 1.0
     baseLearners(0) = firstTreeModel
-    baseLearnerWeights(0) = 1.0
-    val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
-    logDebug("error of gbt = " + loss.computeError(startingModel, input))
+    baseLearnerWeights(0) = firstTreeWeight
+    val startingModel = new GradientBoostedTreesModel(
+      Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1))
+
+    var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
+      computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+    logDebug("error of gbt = " + predError.values.mean())
+
     // Note: A model of type regression is used since we require raw prediction
     timer.stop("building tree 0")
 
-    var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0
+    var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
+      computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
+    var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
     var bestM = 1
 
-    // psuedo-residual for second iteration
-    data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
-      point.features))
+    // pseudo-residual for second iteration
+    data = predError.zip(input).map { case ((pred, _), point) =>
+      LabeledPoint(-loss.gradient(pred, point.label), point.features)
+    }
+
     var m = 1
     while (m < numIterations) {
       timer.start(s"building tree $m")
@@ -222,15 +230,22 @@ object GradientBoostedTrees extends Logging {
       baseLearnerWeights(m) = learningRate
       // Note: A model of type regression is used since we require raw prediction
       val partialModel = new GradientBoostedTreesModel(
-        Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
-      logDebug("error of gbt = " + loss.computeError(partialModel, input))
+        Regression, baseLearners.slice(0, m + 1),
+        baseLearnerWeights.slice(0, m + 1))
+
+      predError = GradientBoostedTreesModel.updatePredictionError(
+        input, predError, baseLearnerWeights(m), baseLearners(m), loss)
+      logDebug("error of gbt = " + predError.values.mean())
 
       if (validate) {
         // Stop training early if
         // 1. Reduction in error is less than the validationTol or
         // 2. If the error increases, that is if the model is overfit.
         // We want the model returned corresponding to the best validation error.
-        val currentValidateError = loss.computeError(partialModel, validationInput)
+        validatePredError = GradientBoostedTreesModel.updatePredictionError(
+          validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
+        val currentValidateError = validatePredError.values.mean()
         if (bestValidateError - currentValidateError < validationTol) {
           return new GradientBoostedTreesModel(
             boostingStrategy.treeStrategy.algo,
@@ -242,8 +257,9 @@ object GradientBoostedTrees extends Logging {
         }
       }
 
       // Update data with pseudo-residuals
-      data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
-        point.features))
+      data = predError.zip(input).map { case ((pred, _), point) =>
+        LabeledPoint(-loss.gradient(pred, point.label), point.features)
+      }
       m += 1
     }
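
In gradient-boosting terms, the pseudo-residual step above relabels each training point with the negative gradient of the loss, now evaluated at the cached prediction rather than by re-scoring the partial ensemble. With $F_m$ denoting the ensemble after $m$ trees:

\[
\tilde{y}_i \;=\; -\left.\frac{\partial L(y_i, F)}{\partial F}\right|_{F = F_m(x_i)},
\qquad
\text{data}_{m+1} \;=\; \{(x_i, \tilde{y}_i)\}_{i=1}^{n},
\]

where $F_m(x_i)$ is exactly the pred component cached in predError, so relabeling costs one gradient evaluation per point instead of $m$ tree predictions.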
@@ -37,14 +37,12 @@ object AbsoluteError extends Loss {
    * Method to calculate the gradients for the gradient boosting calculation for least
    * absolute error calculation.
    * The gradient with respect to F(x) is: sign(F(x) - y)
-   * @param model Ensemble model
-   * @param point Instance of the training dataset
+   * @param prediction Predicted label.
+   * @param label True label.
    * @return Loss gradient
    */
-  override def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double = {
-    if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
+  override def gradient(prediction: Double, label: Double): Double = {
+    if (label - prediction < 0) 1.0 else -1.0
   }
 
   override def computeError(prediction: Double, label: Double): Double = {
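
As a quick check of the sign convention in the new pointwise gradient:

\[
L(y, F) = |y - F|
\quad\Longrightarrow\quad
\frac{\partial L}{\partial F} = \operatorname{sign}(F - y) =
\begin{cases}
\;1 & \text{if } y - F < 0,\\
-1 & \text{otherwise,}
\end{cases}
\]

which is exactly what gradient(prediction, label) returns, with the tie $y = F$ mapped to $-1.0$.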
@@ -39,15 +39,12 @@ object LogLoss extends Loss {
    * Method to calculate the loss gradients for the gradient boosting calculation for binary
    * classification
    * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x)))
-   * @param model Ensemble model
-   * @param point Instance of the training dataset
+   * @param prediction Predicted label.
+   * @param label True label.
    * @return Loss gradient
    */
-  override def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double = {
-    val prediction = model.predict(point.features)
-    - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
+  override def gradient(prediction: Double, label: Double): Double = {
+    - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
   }
 
   override def computeError(prediction: Double, label: Double): Double = {
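
The stated gradient follows from MLlib's convention of labels $y \in \{-1, +1\}$ and the log loss $L(y, F) = 2\log(1 + e^{-2yF})$ (the form used by LogLoss.computeError, stated here as an assumption since that method body is elided above):

\[
\frac{\partial}{\partial F}\, 2\log\!\left(1 + e^{-2yF}\right)
= \frac{-4y\, e^{-2yF}}{1 + e^{-2yF}}
= \frac{-4y}{1 + e^{2yF}},
\]

matching both the doc comment and the one-line implementation.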
@@ -31,13 +31,11 @@ trait Loss extends Serializable {
 
   /**
    * Method to calculate the gradients for the gradient boosting calculation.
-   * @param model Model of the weak learner.
-   * @param point Instance of the training dataset.
+   * @param prediction Predicted label.
+   * @param label True label.
    * @return Loss gradient.
    */
-  def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double
+  def gradient(prediction: Double, label: Double): Double
 
   /**
    * Method to calculate error of the base learner for the gradient boosting calculation.
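
With the model-free signature, a loss is now just a pair of pointwise functions, which also makes new losses easy to write. A minimal sketch, assuming only gradient and the pointwise computeError are abstract on the trait; HuberLoss and delta are hypothetical and not part of this patch:

import org.apache.spark.mllib.tree.loss.Loss

// Hypothetical Huber-style loss written against the new pointwise Loss API.
object HuberLoss extends Loss {
  private val delta = 1.0  // assumed fixed threshold for this sketch

  // With residual r = label - F: d/dF is -r for |r| <= delta,
  // and -delta * sign(r) otherwise.
  override def gradient(prediction: Double, label: Double): Double = {
    val r = label - prediction
    if (math.abs(r) <= delta) -r else -delta * math.signum(r)
  }

  override def computeError(prediction: Double, label: Double): Double = {
    val r = math.abs(label - prediction)
    if (r <= delta) 0.5 * r * r else delta * (r - 0.5 * delta)
  }
}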
@@ -37,14 +37,12 @@ object SquaredError extends Loss {
    * Method to calculate the gradients for the gradient boosting calculation for least
    * squares error calculation.
    * The gradient with respect to F(x) is: - 2 (y - F(x))
-   * @param model Ensemble model
-   * @param point Instance of the training dataset
+   * @param prediction Predicted label.
+   * @param label True label.
    * @return Loss gradient
    */
-  override def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double = {
-    2.0 * (model.predict(point.features) - point.label)
+  override def gradient(prediction: Double, label: Double): Double = {
+    2.0 * (prediction - label)
   }
 
   override def computeError(prediction: Double, label: Double): Double = {
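
And squared error checks out the same way:

\[
L(y, F) = (y - F)^2
\quad\Longrightarrow\quad
\frac{\partial L}{\partial F} = -2\,(y - F) = 2\,(F - y),
\]

which is the returned 2.0 * (prediction - label).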
@@ -130,35 +130,28 @@ class GradientBoostedTreesModel(
     val numIterations = trees.length
     val evaluationArray = Array.fill(numIterations)(0.0)
+    val localTreeWeights = treeWeights
+
+    var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError(
+      remappedData, localTreeWeights(0), trees(0), loss)
 
-    var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
-      val pred = treeWeights(0) * trees(0).predict(i.features)
-      val error = loss.computeError(pred, i.label)
-      (pred, error)
-    }
     evaluationArray(0) = predictionAndError.values.mean()
 
-    // Avoid the model being copied across numIterations.
     val broadcastTrees = sc.broadcast(trees)
-    val broadcastWeights = sc.broadcast(treeWeights)
-
     (1 until numIterations).map { nTree =>
       predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
         val currentTree = broadcastTrees.value(nTree)
-        val currentTreeWeight = broadcastWeights.value(nTree)
-        iter.map {
-          case (point, (pred, error)) => {
-            val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
-            val newError = loss.computeError(newPred, point.label)
-            (newPred, newError)
-          }
+        val currentTreeWeight = localTreeWeights(nTree)
+        iter.map { case (point, (pred, error)) =>
+          val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
+          val newError = loss.computeError(newPred, point.label)
+          (newPred, newError)
         }
       }
       evaluationArray(nTree) = predictionAndError.values.mean()
     }
 
     broadcastTrees.unpersist()
-    broadcastWeights.unpersist()
     evaluationArray
   }
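
Two closure details in the rewritten evaluateEachIteration are easy to miss. Referencing the treeWeights field inside an RDD closure would capture the whole GradientBoostedTreesModel and serialize it into every task, so it is first copied into the local localTreeWeights; the trees, which are the bulk of the model, are instead shared once per executor through a broadcast variable. A minimal illustration of the capture pattern (the class and names here are hypothetical, purely illustrative):

import org.apache.spark.rdd.RDD

// Hypothetical model-like class illustrating the field-capture pitfall.
class ModelLike(val weights: Array[Double], val hugeState: Array[Byte]) extends Serializable {
  def scale(data: RDD[Double]): RDD[Double] = {
    // Capturing `weights` directly means capturing `this`, which would ship
    // `hugeState` with every task:
    //   data.map(x => x * weights(0))

    // Copying the field to a local val captures only the small array:
    val localWeights = weights
    data.map(x => x * localWeights(0))
  }
}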
@@ -166,6 +159,58 @@ class GradientBoostedTreesModel(
 object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
 
+  /**
+   * Compute the initial predictions and errors for a dataset for the first
+   * iteration of gradient boosting.
+   * @param data: training data.
+   * @param initTreeWeight: learning rate assigned to the first tree.
+   * @param initTree: first DecisionTreeModel.
+   * @param loss: evaluation metric.
+   * @return an RDD with each element being a zip of the prediction and error
+   *         corresponding to every sample.
+   */
+  def computeInitialPredictionAndError(
+      data: RDD[LabeledPoint],
+      initTreeWeight: Double,
+      initTree: DecisionTreeModel,
+      loss: Loss): RDD[(Double, Double)] = {
+    data.map { lp =>
+      val pred = initTreeWeight * initTree.predict(lp.features)
+      val error = loss.computeError(pred, lp.label)
+      (pred, error)
+    }
+  }
+
+  /**
+   * Update a zipped predictionError RDD
+   * (as obtained with computeInitialPredictionAndError)
+   * @param data: training data.
+   * @param predictionAndError: predictionError RDD
+   * @param treeWeight: learning rate.
+   * @param tree: tree with which the prediction and error should be updated.
+   * @param loss: evaluation metric.
+   * @return an RDD with each element being a zip of the prediction and error
+   *         corresponding to each sample.
+   */
+  def updatePredictionError(
+      data: RDD[LabeledPoint],
+      predictionAndError: RDD[(Double, Double)],
+      treeWeight: Double,
+      tree: DecisionTreeModel,
+      loss: Loss): RDD[(Double, Double)] = {
+
+    val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
+      iter.map { case (lp, (pred, error)) =>
+        val newPred = pred + tree.predict(lp.features) * treeWeight
+        val newError = loss.computeError(newPred, lp.label)
+        (newPred, newError)
+      }
+    }
+    newPredError
+  }
+
   override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
     val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
     val classNameV1_0 = SaveLoadV1_0.thisClassName
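
Taken together, the two helpers support the same incremental pattern in training and in evaluateEachIteration. A hedged usage sketch (assumes trees: Array[DecisionTreeModel], treeWeights: Array[Double], data: RDD[LabeledPoint], and a Loss instance are already in scope):

// Track (prediction, error) per sample while folding in one tree at a time.
var predError = GradientBoostedTreesModel.computeInitialPredictionAndError(
  data, treeWeights(0), trees(0), loss)
println(s"mean error after 1 tree: ${predError.values.mean()}")

var i = 1
while (i < trees.length) {
  predError = GradientBoostedTreesModel.updatePredictionError(
    data, predError, treeWeights(i), trees(i), loss)
  println(s"mean error after ${i + 1} trees: ${predError.values.mean()}")
  i += 1
}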