diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 7f0397f6bd65a543e129f889a6390ac81e3e40a0..bcbedc8bc108b394738aaf8bdd2df9fd44aa8d18 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -93,6 +93,14 @@ final class DecisionTreeClassifier @Since("1.4.0") (
     trees.head.asInstanceOf[DecisionTreeClassificationModel]
   }
 
+  /** (private[ml]) Train a decision tree on an RDD */
+  private[ml] def train(data: RDD[LabeledPoint],
+      oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
+    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
+      seed = 0L, parentUID = Some(uid))
+    trees.head.asInstanceOf[DecisionTreeClassificationModel]
+  }
+
   /** (private[ml]) Create a Strategy instance to use with the old API. */
   private[ml] def getOldStrategy(
       categoricalFeatures: Map[Int, Int],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index e0ffbedf6cb03da16053071f7bdc1629978ea2a0..82059b1d0ecbb0918953cb05eec204add4897c22 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -26,10 +26,10 @@ import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams,
   TreeEnsembleModel}
+import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
@@ -158,9 +158,8 @@ final class GBTClassifier @Since("1.4.0") (
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
-    val oldGBT = new OldGBT(boostingStrategy)
-    val oldModel = oldGBT.run(oldDataset)
-    GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
+    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
+    new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
   }
 
   @Since("1.4.1")
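Both GBT estimators now delegate to the new `ml.tree.impl.GradientBoostedTrees.run`, which hands back the fitted trees and their weights directly, so no `fromOld` conversion of an old-API model is needed. A minimal sketch of the new contract, assuming `oldDataset` and `boostingStrategy` as in the `train` method above:

    // Sketch only: run() returns index-aligned arrays, so baseLearners(i)
    // is scaled by learnerWeights(i) when the ensemble is evaluated.
    val (baseLearners, learnerWeights) =
      GradientBoostedTrees.run(oldDataset, boostingStrategy)
    assert(baseLearners.length == learnerWeights.length)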
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 897b23383c0cb4b775f4b3e0c5c068fdd04ad08f..6e462924511e02d0e6a7b1c7740d41d69b175221 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -87,6 +87,14 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
     trees.head.asInstanceOf[DecisionTreeRegressionModel]
   }
 
+  /** (private[ml]) Train a decision tree on an RDD */
+  private[ml] def train(data: RDD[LabeledPoint],
+      oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
+    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
+      seed = 0L, parentUID = Some(uid))
+    trees.head.asInstanceOf[DecisionTreeRegressionModel]
+  }
+
   /** (private[ml]) Create a Strategy instance to use with the old API. */
   private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
     super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 9c842a6c88202dc057546eeeae5c645f9634f359..4cc2721aefb22cbb48aba9eebe0bcf1e1ea1e90e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -25,10 +25,10 @@ import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel,
   TreeRegressorParams}
+import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
   SquaredError => OldSquaredError}
@@ -145,9 +145,8 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
-    val oldGBT = new OldGBT(boostingStrategy)
-    val oldModel = oldGBT.run(oldDataset)
-    GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
+    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
+    new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
   }
 
   @Since("1.4.0")
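The `private[ml] train(data, oldStrategy)` helpers added to both tree estimators above exist so that the boosting loop in the new file below can fit one tree per iteration without converting models through the old API. A minimal sketch of that per-iteration call, assuming `data` holds the current pseudo-residuals and `treeStrategy` is the per-tree strategy prepared by `boost`:

    // RandomForest.run with numTrees = 1 and featureSubsetStrategy = "all"
    // reduces to ordinary single-tree training.
    val dt = new DecisionTreeRegressor()
    val model: DecisionTreeRegressionModel = dt.train(data, treeStrategy)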
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
new file mode 100644
index 0000000000000000000000000000000000000000..44ab5b723bd7abef1461a4143bfbceb55190003c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.Logging
+import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
+import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+private[ml] object GradientBoostedTrees extends Logging {
+
+  /**
+   * Method to train a gradient boosting model
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return tuple of ensemble models and weights:
+   *         (array of decision tree models, array of model weights)
+   */
+  def run(input: RDD[LabeledPoint],
+      boostingStrategy: OldBoostingStrategy
+    ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+    val algo = boostingStrategy.treeStrategy.algo
+    algo match {
+      case OldAlgo.Regression =>
+        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
+      case OldAlgo.Classification =>
+        // Map labels to -1, +1 so binary classification can be treated as regression.
+        val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+        GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
+      case _ =>
+        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
+    }
+  }
+
+  /**
+   * Method to train a gradient boosting model with validation (early stopping).
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @param validationInput Validation dataset.
+   *                        This dataset should be different from the training dataset,
+   *                        but it should follow the same distribution.
+   *                        E.g., these two datasets could be created from an original dataset
+   *                        by using [[org.apache.spark.rdd.RDD.randomSplit()]]
+   * @return tuple of ensemble models and weights:
+   *         (array of decision tree models, array of model weights)
+   */
+  def runWithValidation(
+      input: RDD[LabeledPoint],
+      validationInput: RDD[LabeledPoint],
+      boostingStrategy: OldBoostingStrategy
+    ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+    val algo = boostingStrategy.treeStrategy.algo
+    algo match {
+      case OldAlgo.Regression =>
+        GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
+      case OldAlgo.Classification =>
+        // Map labels to -1, +1 so binary classification can be treated as regression.
+        val remappedInput = input.map(
+          x => new LabeledPoint((x.label * 2) - 1, x.features))
+        val remappedValidationInput = validationInput.map(
+          x => new LabeledPoint((x.label * 2) - 1, x.features))
+        GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
+          validate = true)
+      case _ =>
+        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
+    }
+  }
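+
+  // Note: the (x.label * 2) - 1 remapping used above sends binary labels
+  // {0.0, 1.0} to {-1.0, +1.0}, matching the margin-based losses (e.g.
+  // LogLoss) that are defined over +/-1 labels.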
+
+  /**
+   * Compute the initial predictions and errors for a dataset for the first
+   * iteration of gradient boosting.
+   * @param data: training data.
+   * @param initTreeWeight: learning rate assigned to the first tree.
+   * @param initTree: first DecisionTreeModel.
+   * @param loss: evaluation metric.
+   * @return an RDD with each element being a zip of the prediction and error
+   *         corresponding to every sample.
+   */
+  def computeInitialPredictionAndError(
+      data: RDD[LabeledPoint],
+      initTreeWeight: Double,
+      initTree: DecisionTreeRegressionModel,
+      loss: OldLoss): RDD[(Double, Double)] = {
+    data.map { lp =>
+      val pred = initTreeWeight * initTree.rootNode.predictImpl(lp.features).prediction
+      val error = loss.computeError(pred, lp.label)
+      (pred, error)
+    }
+  }
+
+  /**
+   * Update a zipped predictionError RDD
+   * (as obtained with computeInitialPredictionAndError)
+   * @param data: training data.
+   * @param predictionAndError: predictionError RDD
+   * @param treeWeight: Learning rate.
+   * @param tree: Tree using which the prediction and error should be updated.
+   * @param loss: evaluation metric.
+   * @return an RDD with each element being a zip of the prediction and error
+   *         corresponding to each sample.
+   */
+  def updatePredictionError(
+      data: RDD[LabeledPoint],
+      predictionAndError: RDD[(Double, Double)],
+      treeWeight: Double,
+      tree: DecisionTreeRegressionModel,
+      loss: OldLoss): RDD[(Double, Double)] = {
+
+    val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
+      iter.map { case (lp, (pred, error)) =>
+        val newPred = pred + tree.rootNode.predictImpl(lp.features).prediction * treeWeight
+        val newError = loss.computeError(newPred, lp.label)
+        (newPred, newError)
+      }
+    }
+    newPredError
+  }
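+
+  // For example, after tree m the running prediction for a sample is
+  //   pred_m = pred_{m-1} + treeWeight_m * tree_m(features)
+  // and the stored error is loss.computeError(pred_m, label).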
+
+  /**
+   * Internal method for performing regression using trees as base learners.
+   * @param input training dataset
+   * @param validationInput validation dataset, ignored if validate is set to false.
+   * @param boostingStrategy boosting parameters
+   * @param validate whether or not to use the validation dataset.
+   * @return tuple of ensemble models and weights:
+   *         (array of decision tree models, array of model weights)
+   */
+  def boost(
+      input: RDD[LabeledPoint],
+      validationInput: RDD[LabeledPoint],
+      boostingStrategy: OldBoostingStrategy,
+      validate: Boolean): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+    val timer = new TimeTracker()
+    timer.start("total")
+    timer.start("init")
+
+    boostingStrategy.assertValid()
+
+    // Initialize gradient boosting parameters
+    val numIterations = boostingStrategy.numIterations
+    val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
+    val baseLearnerWeights = new Array[Double](numIterations)
+    val loss = boostingStrategy.loss
+    val learningRate = boostingStrategy.learningRate
+    // Prepare strategy for individual trees, which use regression with variance impurity.
+    val treeStrategy = boostingStrategy.treeStrategy.copy
+    val validationTol = boostingStrategy.validationTol
+    treeStrategy.algo = OldAlgo.Regression
+    treeStrategy.impurity = OldVariance
+    treeStrategy.assertValid()
+
+    // Cache input
+    val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
+      input.persist(StorageLevel.MEMORY_AND_DISK)
+      true
+    } else {
+      false
+    }
+
+    // Prepare periodic checkpointers
+    val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+      treeStrategy.getCheckpointInterval, input.sparkContext)
+    val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+      treeStrategy.getCheckpointInterval, input.sparkContext)
+
+    timer.stop("init")
+
+    logDebug("##########")
+    logDebug("Building tree 0")
+    logDebug("##########")
+
+    // Initialize tree
+    timer.start("building tree 0")
+    val firstTree = new DecisionTreeRegressor()
+    val firstTreeModel = firstTree.train(input, treeStrategy)
+    val firstTreeWeight = 1.0
+    baseLearners(0) = firstTreeModel
+    baseLearnerWeights(0) = firstTreeWeight
+
+    var predError: RDD[(Double, Double)] =
+      computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+    predErrorCheckpointer.update(predError)
+    logDebug("error of gbt = " + predError.values.mean())
+
+    // Note: A model of type regression is used since we require raw prediction
+    timer.stop("building tree 0")
+
+    var validatePredError: RDD[(Double, Double)] =
+      computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
+    if (validate) validatePredErrorCheckpointer.update(validatePredError)
+    var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
+    var bestM = 1
+
+    var m = 1
+    var doneLearning = false
+    while (m < numIterations && !doneLearning) {
+      // Update data with pseudo-residuals
+      val data = predError.zip(input).map { case ((pred, _), point) =>
+        LabeledPoint(-loss.gradient(pred, point.label), point.features)
+      }
+
+      timer.start(s"building tree $m")
+      logDebug("###################################################")
+      logDebug("Gradient boosting tree iteration " + m)
+      logDebug("###################################################")
+      val dt = new DecisionTreeRegressor()
+      val model = dt.train(data, treeStrategy)
+      timer.stop(s"building tree $m")
+      // Update partial model
+      baseLearners(m) = model
+      // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+      //       Technically, the weight should be optimized for the particular loss.
+      //       However, the behavior should be reasonable, though not optimal.
+      baseLearnerWeights(m) = learningRate
+
+      predError = updatePredictionError(
+        input, predError, baseLearnerWeights(m), baseLearners(m), loss)
+      predErrorCheckpointer.update(predError)
+      logDebug("error of gbt = " + predError.values.mean())
+
+      if (validate) {
+        // Stop training early if:
+        // 1. the reduction in validation error is less than validationTol, or
+        // 2. the validation error increases, i.e., the model is overfitting.
+        // We return the model corresponding to the best validation error.
+
+        validatePredError = updatePredictionError(
+          validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
+        validatePredErrorCheckpointer.update(validatePredError)
+        val currentValidateError = validatePredError.values.mean()
+        if (bestValidateError - currentValidateError < validationTol * Math.max(
+            currentValidateError, 0.01)) {
+          doneLearning = true
+        } else if (currentValidateError < bestValidateError) {
+          bestValidateError = currentValidateError
+          bestM = m + 1
+        }
+      }
+      m += 1
+    }
+
+    timer.stop("total")
+
+    logInfo("Internal timing for GradientBoostedTrees:")
+    logInfo(s"$timer")
+
+    predErrorCheckpointer.deleteAllCheckpoints()
+    validatePredErrorCheckpointer.deleteAllCheckpoints()
+    if (persistedInput) input.unpersist()
+
+    if (validate) {
+      (baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM))
+    } else {
+      (baseLearners, baseLearnerWeights)
+    }
+  }
+}
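Callers that want the early-stopping behavior split their data and use `runWithValidation`; per the doc comment above, the two datasets should follow the same distribution, e.g. via `randomSplit`. A hedged usage sketch, where `labeledData` is a placeholder `RDD[LabeledPoint]`:

    // Hypothetical caller: hold out 25% of the data to drive early stopping.
    val Array(trainData, validationData) = labeledData.randomSplit(Array(0.75, 0.25))
    val (trees, weights) =
      GradientBoostedTrees.runWithValidation(trainData, validationData, boostingStrategy)
    // trees.length may be smaller than numIterations: boosting stops once the
    // improvement in validation error falls below validationTol, and only the
    // best-scoring prefix of the ensemble (bestM trees) is returned.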
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
index f31ed2aa90a6420bc079e2b4a0c4c53d91f34e47..145dc22b7428e7194403a46f79fc36c4d0f8fc7d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
@@ -74,7 +74,7 @@ import org.apache.spark.storage.StorageLevel
  *
  * TODO: Move this out of MLlib?
  */
-private[mllib] class PeriodicRDDCheckpointer[T](
+private[spark] class PeriodicRDDCheckpointer[T](
     checkpointInterval: Int,
     sc: SparkContext)
   extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 0b118a76733fd8c2c16357350b8bbfeec7d74ba3..d8405d13ce904a9cecf387dfcd148639212a32e3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -59,7 +59,7 @@ case class BoostingStrategy @Since("1.4.0") (
    * Check validity of parameters.
    * Throws exception if invalid.
    */
-  private[tree] def assertValid(): Unit = {
+  private[spark] def assertValid(): Unit = {
     treeStrategy.algo match {
       case Classification =>
         require(treeStrategy.numClasses == 2,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 9e3e50192d507018cdb3dfee99401ed51e60bbec..8a0907564e728f9124f9bc7b72858509583cf022 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -133,7 +133,7 @@ class Strategy @Since("1.3.0") (
    * Check validity of parameters.
    * Throws exception if invalid.
    */
-  private[tree] def assertValid(): Unit = {
+  private[spark] def assertValid(): Unit = {
     algo match {
       case Classification =>
         require(numClasses >= 2,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index 48a4e38a346d6ff84cd273ed68ef323351c36581..9b60d018d0eda709a734c03d380b3253f49ff770 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -45,7 +45,7 @@ object AbsoluteError extends Loss {
     if (label - prediction < 0) 1.0 else -1.0
   }
 
-  override private[mllib] def computeError(prediction: Double, label: Double): Double = {
+  override private[spark] def computeError(prediction: Double, label: Double): Double = {
     val err = label - prediction
     math.abs(err)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index b88743c0dbab6dd219c942a19256a08a7d4a5c64..5d92ce495b04ddd14e70f058bc35fbfb6086681a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -47,7 +47,7 @@ object LogLoss extends Loss {
     - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
   }
 
-  override private[mllib] def computeError(prediction: Double, label: Double): Double = {
+  override private[spark] def computeError(prediction: Double, label: Double): Double = {
     val margin = 2.0 * label * prediction
     // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
     2.0 * MLUtils.log1pExp(-margin)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index 687cde325ffed4653a5e1c11a2613c88e1fbc413..de14ddf024d75bfb4f0368d5b6257c90e4ea7f8d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -61,5 +61,5 @@ trait Loss extends Serializable {
    * @param label True label.
    * @return Measure of model error on datapoint.
    */
-  private[mllib] def computeError(prediction: Double, label: Double): Double
+  private[spark] def computeError(prediction: Double, label: Double): Double
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index cb97f6fd29d950a2bd714c80a0decfb2eda4bfbb..4eb6810c46b20abd333fea22bdff5a2fb11cdb82 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -45,7 +45,7 @@ object SquaredError extends Loss {
     - 2.0 * (label - prediction)
   }
 
-  override private[mllib] def computeError(prediction: Double, label: Double): Double = {
+  override private[spark] def computeError(prediction: Double, label: Double): Double = {
     val err = label - prediction
     err * err
   }
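The remaining hunks widen `private[mllib]` and `private[tree]` members to `private[spark]` because the new implementation lives under `org.apache.spark.ml`, which is outside the `mllib` (and `tree`) package trees but still under `org.apache.spark`. For instance, the error bookkeeping in `GradientBoostedTrees` above depends on this call compiling from `ml.tree.impl`:

    // From org.apache.spark.ml.tree.impl: visible only because computeError
    // is now private[spark] rather than private[mllib].
    val error = loss.computeError(pred, lp.label)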