Commit dafd70fb authored by sethah, committed by Nick Pentreath

[SPARK-12379][ML][MLLIB] Copy GBT implementation to spark.ml

Currently, GBTs in spark.ml wrap the implementation in spark.mllib. This blocks several planned improvements to GBTs in spark.ml, so this patch moves the implementation into spark.ml and builds it on spark.ml decision trees. This first pass keeps the changes to the implementation minimal.
Performance testing was done to ensure there were no regressions.

Performance testing results are [here](https://docs.google.com/document/d/1dYd2mnfGdUKkQ3vZe2BpzsTnI5IrpSLQ-NNKDZhUkgw/edit?usp=sharing)
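Note that the caller-facing spark.ml API is untouched by this change; only the training path beneath GBTClassifier/GBTRegressor moves. For reference, a minimal caller-side sketch (the DataFrame name and parameter values here are illustrative, not part of this patch):

    import org.apache.spark.ml.classification.GBTClassifier

    // Assumes trainingData: DataFrame with "label" and "features" columns.
    val gbt = new GBTClassifier()
      .setMaxIter(10)
      .setMaxDepth(5)
    val model = gbt.fit(trainingData)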

Author: sethah <seth.hendrickson16@gmail.com>

Closes #10607 from sethah/SPARK-12379.
parent 10251a74
Showing with 306 additions and 15 deletions
@@ -93,6 +93,14 @@ final class DecisionTreeClassifier @Since("1.4.0") (
     trees.head.asInstanceOf[DecisionTreeClassificationModel]
   }
 
+  /** (private[ml]) Train a decision tree on an RDD */
+  private[ml] def train(data: RDD[LabeledPoint],
+      oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
+    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
+      seed = 0L, parentUID = Some(uid))
+    trees.head.asInstanceOf[DecisionTreeClassificationModel]
+  }
+
   /** (private[ml]) Create a Strategy instance to use with the old API. */
   private[ml] def getOldStrategy(
       categoricalFeatures: Map[Int, Int],
...
@@ -26,10 +26,10 @@ import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams,
   TreeEnsembleModel}
+import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
@@ -158,9 +158,8 @@ final class GBTClassifier @Since("1.4.0") (
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
-    val oldGBT = new OldGBT(boostingStrategy)
-    val oldModel = oldGBT.run(oldDataset)
-    GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
+    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
+    new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
   }
 
   @Since("1.4.1")
...
@@ -87,6 +87,14 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
     trees.head.asInstanceOf[DecisionTreeRegressionModel]
   }
 
+  /** (private[ml]) Train a decision tree on an RDD */
+  private[ml] def train(data: RDD[LabeledPoint],
+      oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
+    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
+      seed = 0L, parentUID = Some(uid))
+    trees.head.asInstanceOf[DecisionTreeRegressionModel]
+  }
+
   /** (private[ml]) Create a Strategy instance to use with the old API. */
   private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
     super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
...
@@ -25,10 +25,10 @@ import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel,
   TreeRegressorParams}
+import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
   SquaredError => OldSquaredError}
@@ -145,9 +145,8 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
-    val oldGBT = new OldGBT(boostingStrategy)
-    val oldModel = oldGBT.run(oldDataset)
-    GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
+    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
+    new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
   }
 
   @Since("1.4.0")
...
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.tree.impl

import org.apache.spark.Logging
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

private[ml] object GradientBoostedTrees extends Logging {

  /**
   * Method to train a gradient boosting model
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   * @return tuple of ensemble models and weights:
   *         (array of decision tree models, array of model weights)
   */
  def run(input: RDD[LabeledPoint],
      boostingStrategy: OldBoostingStrategy
    ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
    val algo = boostingStrategy.treeStrategy.algo
    algo match {
      case OldAlgo.Regression =>
        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
      case OldAlgo.Classification =>
        // Map labels to -1, +1 so binary classification can be treated as regression.
        val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
        GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
      case _ =>
        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
    }
  }

  /**
   * Method to train a gradient boosting model with validation-based early stopping.
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   * @param validationInput Validation dataset.
   *                        This dataset should be different from the training dataset,
   *                        but it should follow the same distribution.
   *                        E.g., these two datasets could be created from an original dataset
   *                        by using [[org.apache.spark.rdd.RDD.randomSplit()]]
   * @return tuple of ensemble models and weights:
   *         (array of decision tree models, array of model weights)
   */
  def runWithValidation(
      input: RDD[LabeledPoint],
      validationInput: RDD[LabeledPoint],
      boostingStrategy: OldBoostingStrategy
    ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
    val algo = boostingStrategy.treeStrategy.algo
    algo match {
      case OldAlgo.Regression =>
        GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
      case OldAlgo.Classification =>
        // Map labels to -1, +1 so binary classification can be treated as regression.
        val remappedInput = input.map(
          x => new LabeledPoint((x.label * 2) - 1, x.features))
        val remappedValidationInput = validationInput.map(
          x => new LabeledPoint((x.label * 2) - 1, x.features))
        GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
          validate = true)
      case _ =>
        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
    }
  }

  /**
   * Compute the initial predictions and errors for a dataset for the first
   * iteration of gradient boosting.
   * @param data: training data.
   * @param initTreeWeight: learning rate assigned to the first tree.
   * @param initTree: first DecisionTreeModel.
   * @param loss: evaluation metric.
   * @return an RDD with each element being a zip of the prediction and error
   *         corresponding to every sample.
   */
  def computeInitialPredictionAndError(
      data: RDD[LabeledPoint],
      initTreeWeight: Double,
      initTree: DecisionTreeRegressionModel,
      loss: OldLoss): RDD[(Double, Double)] = {
    data.map { lp =>
      val pred = initTreeWeight * initTree.rootNode.predictImpl(lp.features).prediction
      val error = loss.computeError(pred, lp.label)
      (pred, error)
    }
  }

  /**
   * Update a zipped predictionError RDD
   * (as obtained with computeInitialPredictionAndError)
   * @param data: training data.
   * @param predictionAndError: predictionError RDD
   * @param treeWeight: learning rate.
   * @param tree: the new tree whose predictions are folded into the running prediction and error.
   * @param loss: evaluation metric.
   * @return an RDD with each element being a zip of the prediction and error
   *         corresponding to each sample.
   */
  def updatePredictionError(
      data: RDD[LabeledPoint],
      predictionAndError: RDD[(Double, Double)],
      treeWeight: Double,
      tree: DecisionTreeRegressionModel,
      loss: OldLoss): RDD[(Double, Double)] = {
    val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
      iter.map { case (lp, (pred, error)) =>
        val newPred = pred + tree.rootNode.predictImpl(lp.features).prediction * treeWeight
        val newError = loss.computeError(newPred, lp.label)
        (newPred, newError)
      }
    }
    newPredError
  }

  /**
   * Internal method for performing regression using trees as base learners.
   * @param input training dataset
   * @param validationInput validation dataset, ignored if validate is set to false.
   * @param boostingStrategy boosting parameters
   * @param validate whether or not to use the validation dataset.
   * @return tuple of ensemble models and weights:
   *         (array of decision tree models, array of model weights)
   */
  def boost(
      input: RDD[LabeledPoint],
      validationInput: RDD[LabeledPoint],
      boostingStrategy: OldBoostingStrategy,
      validate: Boolean): (Array[DecisionTreeRegressionModel], Array[Double]) = {
    val timer = new TimeTracker()
    timer.start("total")
    timer.start("init")

    boostingStrategy.assertValid()

    // Initialize gradient boosting parameters
    val numIterations = boostingStrategy.numIterations
    val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
    val baseLearnerWeights = new Array[Double](numIterations)
    val loss = boostingStrategy.loss
    val learningRate = boostingStrategy.learningRate
    // Prepare strategy for individual trees, which use regression with variance impurity.
    val treeStrategy = boostingStrategy.treeStrategy.copy
    val validationTol = boostingStrategy.validationTol
    treeStrategy.algo = OldAlgo.Regression
    treeStrategy.impurity = OldVariance
    treeStrategy.assertValid()

    // Cache input
    val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
      input.persist(StorageLevel.MEMORY_AND_DISK)
      true
    } else {
      false
    }

    // Prepare periodic checkpointers
    val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      treeStrategy.getCheckpointInterval, input.sparkContext)
    val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      treeStrategy.getCheckpointInterval, input.sparkContext)

    timer.stop("init")

    logDebug("##########")
    logDebug("Building tree 0")
    logDebug("##########")

    // Initialize tree
    timer.start("building tree 0")
    val firstTree = new DecisionTreeRegressor()
    val firstTreeModel = firstTree.train(input, treeStrategy)
    val firstTreeWeight = 1.0
    baseLearners(0) = firstTreeModel
    baseLearnerWeights(0) = firstTreeWeight

    var predError: RDD[(Double, Double)] =
      computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
    predErrorCheckpointer.update(predError)
    logDebug("error of gbt = " + predError.values.mean())

    // Note: A model of type regression is used since we require raw prediction
    timer.stop("building tree 0")

    var validatePredError: RDD[(Double, Double)] =
      computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
    if (validate) validatePredErrorCheckpointer.update(validatePredError)
    var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
    var bestM = 1

    var m = 1
    var doneLearning = false
    while (m < numIterations && !doneLearning) {
      // Update data with pseudo-residuals
      val data = predError.zip(input).map { case ((pred, _), point) =>
        LabeledPoint(-loss.gradient(pred, point.label), point.features)
      }

      timer.start(s"building tree $m")
      logDebug("###################################################")
      logDebug("Gradient boosting tree iteration " + m)
      logDebug("###################################################")
      val dt = new DecisionTreeRegressor()
      val model = dt.train(data, treeStrategy)
      timer.stop(s"building tree $m")
      // Update partial model
      baseLearners(m) = model
      // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
      // Technically, the weight should be optimized for the particular loss.
      // However, the behavior should be reasonable, though not optimal.
      baseLearnerWeights(m) = learningRate

      predError = updatePredictionError(
        input, predError, baseLearnerWeights(m), baseLearners(m), loss)
      predErrorCheckpointer.update(predError)
      logDebug("error of gbt = " + predError.values.mean())

      if (validate) {
        // Stop training early if:
        // 1. the reduction in validation error is less than validationTol, or
        // 2. the validation error increases, i.e. the model is overfitting.
        // We return the model corresponding to the best validation error.
        validatePredError = updatePredictionError(
          validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
        validatePredErrorCheckpointer.update(validatePredError)
        val currentValidateError = validatePredError.values.mean()
        if (bestValidateError - currentValidateError < validationTol * Math.max(
          currentValidateError, 0.01)) {
          doneLearning = true
        } else if (currentValidateError < bestValidateError) {
          bestValidateError = currentValidateError
          bestM = m + 1
        }
      }
      m += 1
    }

    timer.stop("total")

    logInfo("Internal timing for DecisionTree:")
    logInfo(s"$timer")

    predErrorCheckpointer.deleteAllCheckpoints()
    validatePredErrorCheckpointer.deleteAllCheckpoints()
    if (persistedInput) input.unpersist()

    if (validate) {
      (baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM))
    } else {
      (baseLearners, baseLearnerWeights)
    }
  }
}
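The runWithValidation entry point above enables validation-based early stopping; as its scaladoc suggests, the two datasets can be produced with randomSplit. A minimal sketch, assuming the caller lives under org.apache.spark.ml (the object is private[ml]) and a hypothetical RDD named points, using the import aliases already present in this file:

    // Hypothetical usage; points: RDD[LabeledPoint] is assumed, not part of this patch.
    val Array(trainData, validationData) = points.randomSplit(Array(0.8, 0.2), seed = 42L)
    val strategy = OldBoostingStrategy.defaultParams(OldAlgo.Regression)
    val (trees, treeWeights) =
      GradientBoostedTrees.runWithValidation(trainData, validationData, strategy)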
@@ -74,7 +74,7 @@ import org.apache.spark.storage.StorageLevel
  *
  * TODO: Move this out of MLlib?
  */
-private[mllib] class PeriodicRDDCheckpointer[T](
+private[spark] class PeriodicRDDCheckpointer[T](
     checkpointInterval: Int,
     sc: SparkContext)
   extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {
...
@@ -59,7 +59,7 @@ case class BoostingStrategy @Since("1.4.0") (
    * Check validity of parameters.
    * Throws exception if invalid.
    */
-  private[tree] def assertValid(): Unit = {
+  private[spark] def assertValid(): Unit = {
     treeStrategy.algo match {
       case Classification =>
         require(treeStrategy.numClasses == 2,
...
@@ -133,7 +133,7 @@ class Strategy @Since("1.3.0") (
    * Check validity of parameters.
    * Throws exception if invalid.
    */
-  private[tree] def assertValid(): Unit = {
+  private[spark] def assertValid(): Unit = {
     algo match {
       case Classification =>
         require(numClasses >= 2,
...
@@ -45,7 +45,7 @@ object AbsoluteError extends Loss {
     if (label - prediction < 0) 1.0 else -1.0
   }
 
-  override private[mllib] def computeError(prediction: Double, label: Double): Double = {
+  override private[spark] def computeError(prediction: Double, label: Double): Double = {
     val err = label - prediction
     math.abs(err)
   }
...
@@ -47,7 +47,7 @@ object LogLoss extends Loss {
     - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
   }
 
-  override private[mllib] def computeError(prediction: Double, label: Double): Double = {
+  override private[spark] def computeError(prediction: Double, label: Double): Double = {
     val margin = 2.0 * label * prediction
     // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
     2.0 * MLUtils.log1pExp(-margin)
...
@@ -61,5 +61,5 @@ trait Loss extends Serializable {
    * @param label True label.
    * @return Measure of model error on datapoint.
    */
-  private[mllib] def computeError(prediction: Double, label: Double): Double
+  private[spark] def computeError(prediction: Double, label: Double): Double
 }
@@ -45,7 +45,7 @@ object SquaredError extends Loss {
     - 2.0 * (label - prediction)
   }
 
-  override private[mllib] def computeError(prediction: Double, label: Double): Double = {
+  override private[spark] def computeError(prediction: Double, label: Double): Double = {
     val err = label - prediction
     err * err
   }
...