Commit fe409f31 authored by Ilya Matiach, committed by Joseph K. Bradley

[SPARK-14975][ML] Fixed GBTClassifier to predict probability per training instance and fixed interfaces

## What changes were proposed in this pull request?

All of the classifiers in MLlib can predict probabilities except GBTClassifier.
Moreover, every other classifier inherits from ProbabilisticClassifier, but GBTClassifier inherits directly from Predictor, which is a bug.
This change corrects the interface and adds the ability for the classifier to output a probability vector.
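
As a quick sketch of the user-facing behavior this enables (column names are the spark.ml defaults, and `train` stands in for any DataFrame with "label" and "features" columns, mirroring the new tests):

```scala
import org.apache.spark.ml.classification.GBTClassifier

val gbt = new GBTClassifier()
  .setMaxIter(10)
  .setSeed(123)

val model = gbt.fit(train)

// With this change the model exposes the ProbabilisticClassifier columns
// "rawPrediction" and "probability" in addition to "prediction".
model.transform(train)
  .select("label", "rawPrediction", "probability", "prediction")
  .show(5)
```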

## How was this patch tested?

The basic ML tests were run after making the changes. I've marked this as WIP because I still need to add more tests.

Author: Ilya Matiach <ilmat@microsoft.com>

Closes #16441 from imatiach-msft/ilmat/fix-GBT.
parent a81e336f
@@ -23,9 +23,8 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
@@ -33,6 +32,7 @@ import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.loss.LogLoss
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -58,7 +58,7 @@ import org.apache.spark.sql.functions._
@Since("1.4.0")
class GBTClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
extends ProbabilisticClassifier[Vector, GBTClassifier, GBTClassificationModel]
with GBTClassifierParams with DefaultParamsWritable with Logging {
@Since("1.4.0")
@@ -158,12 +158,19 @@ class GBTClassifier @Since("1.4.0") (
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val numClasses = 2
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
val instr = Instrumentation.create(this, oldDataset)
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
instr.logNumFeatures(numFeatures)
instr.logNumClasses(2)
instr.logNumClasses(numClasses)
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
$(seed))
@@ -202,8 +209,9 @@ class GBTClassificationModel private[ml](
@Since("1.6.0") override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double],
@Since("1.6.0") override val numFeatures: Int)
extends PredictionModel[Vector, GBTClassificationModel]
@Since("1.6.0") override val numFeatures: Int,
@Since("2.2.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, GBTClassificationModel]
with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable {
@@ -211,6 +219,20 @@ class GBTClassificationModel private[ml](
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
/**
* Construct a GBTClassificationModel
*
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
* @param numFeatures The number of features.
*/
private[ml] def this(
uid: String,
_trees: Array[DecisionTreeRegressionModel],
_treeWeights: Array[Double],
numFeatures: Int) =
this(uid, _trees, _treeWeights, numFeatures, 2)
/**
* Construct a GBTClassificationModel
*
@@ -219,7 +241,7 @@ class GBTClassificationModel private[ml](
*/
@Since("1.6.0")
def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
this(uid, _trees, _treeWeights, -1)
this(uid, _trees, _treeWeights, -1, 2)
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees
@@ -242,11 +264,29 @@ class GBTClassificationModel private[ml](
}
override protected def predict(features: Vector): Double = {
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
if (prediction > 0.0) 1.0 else 0.0
// If thresholds defined, use predictRaw to get probabilities, otherwise use optimization
if (isDefined(thresholds)) {
super.predict(features)
} else {
if (margin(features) > 0.0) 1.0 else 0.0
}
}
override protected def predictRaw(features: Vector): Vector = {
val prediction: Double = margin(features)
Vectors.dense(Array(-prediction, prediction))
}
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
rawPrediction match {
case dv: DenseVector =>
dv.values(0) = loss.computeProbability(dv.values(0))
dv.values(1) = 1.0 - dv.values(0)
dv
case sv: SparseVector =>
throw new RuntimeException("Unexpected error in GBTClassificationModel:" +
" raw2probabilityInPlace encountered SparseVector")
}
}
/** Number of trees in ensemble */
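
The branch on `isDefined(thresholds)` above matters for behavior: with thresholds set, prediction routes through `super.predict` (raw prediction → probability → thresholded prediction); without them, only the sign of the margin is computed. A minimal sketch of the visible effect, mirroring the new thresholds test further down (it assumes `model` is a fitted GBTClassificationModel and `df` has a "features" column):

```scala
// The predicted class i maximizes probability(i) / thresholds(i), so a
// threshold of 0.0 makes that class win unconditionally.
model.setThresholds(Array(0.0, 1.0))
model.transform(df).select("prediction") // every row predicted 0.0

model.setThresholds(Array(1.0, 0.0))
model.transform(df).select("prediction") // every row predicted 1.0
```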
@@ -254,7 +294,7 @@ class GBTClassificationModel private[ml](
@Since("1.4.0")
override def copy(extra: ParamMap): GBTClassificationModel = {
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses),
extra).setParent(parent)
}
@@ -276,11 +316,20 @@ class GBTClassificationModel private[ml](
@Since("2.0.0")
lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
/** Raw prediction for the positive class. */
private def margin(features: Vector): Double = {
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
}
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
}
// hard coded loss, which is not meant to be changed in the model
private val loss = getOldLossType
@Since("2.0.0")
override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
}
@@ -288,6 +337,9 @@ class GBTClassificationModel private[ml](
@Since("2.0.0")
object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
private val numFeaturesKey: String = "numFeatures"
private val numTreesKey: String = "numTrees"
@Since("2.0.0")
override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader
@@ -300,8 +352,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
override protected def saveImpl(path: String): Unit = {
val extraMetadata: JObject = Map(
"numFeatures" -> instance.numFeatures,
"numTrees" -> instance.getNumTrees)
numFeaturesKey -> instance.numFeatures,
numTreesKey -> instance.getNumTrees)
EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
}
}
@@ -316,8 +368,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
val trees: Array[DecisionTreeRegressionModel] = treesData.map {
case (treeMetadata, root) =>
@@ -328,7 +380,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
}
require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
s" trees based on metadata but found ${trees.length} trees.")
val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures)
val model = new GBTClassificationModel(metadata.uid,
trees, treeWeights, numFeatures)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
@@ -339,7 +392,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
oldModel: OldGBTModel,
parent: GBTClassifier,
categoricalFeatures: Map[Int, Int],
numFeatures: Int = -1): GBTClassificationModel = {
numFeatures: Int = -1,
numClasses: Int = 2): GBTClassificationModel = {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
@@ -347,6 +401,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses)
}
}
@@ -25,7 +25,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, ClassificationLoss => OldClassificationLoss, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
@@ -531,7 +531,7 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam
def getLossType: String = $(lossType).toLowerCase
/** (private[ml]) Convert new loss to old loss. */
override private[ml] def getOldLossType: OldLoss = {
override private[ml] def getOldLossType: OldClassificationLoss = {
getLossType match {
case "logistic" => OldLogLoss
case _ =>
......
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.mllib.util.MLUtils
/**
* :: DeveloperApi ::
* Class for log loss calculation (for classification).
@@ -32,7 +31,7 @@ import org.apache.spark.mllib.util.MLUtils
*/
@Since("1.2.0")
@DeveloperApi
object LogLoss extends Loss {
object LogLoss extends ClassificationLoss {
/**
* Method to calculate the loss gradients for the gradient boosting calculation for binary
@@ -52,4 +51,11 @@ object LogLoss extends Loss {
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
2.0 * MLUtils.log1pExp(-margin)
}
/**
* Returns the estimated probability of a label of 1.0.
*/
override private[spark] def computeProbability(margin: Double): Double = {
1.0 / (1.0 + math.exp(-2.0 * margin))
}
}
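
Taken together with `predictRaw` above, the probability computation is a scaled sigmoid of the ensemble margin: rawPrediction is `[-margin, margin]`, and each component is mapped through `computeProbability`. A standalone sketch of the mapping, only restating the code in this diff:

```scala
// margin = sum over trees of treeWeight(i) * treePrediction(i)
// P(label = 1 | features) under the logistic loss used by GBTClassifier:
def probabilityOfPositive(margin: Double): Double =
  1.0 / (1.0 + math.exp(-2.0 * margin))

// margin = 0.0   -> 0.5 (maximal uncertainty)
// margin -> +inf -> 1.0; margin -> -inf -> 0.0
```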
@@ -22,7 +22,6 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
/**
* :: DeveloperApi ::
* Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
@@ -67,3 +66,10 @@ trait Loss extends Serializable {
*/
private[spark] def computeError(prediction: Double, label: Double): Double
}
private[spark] trait ClassificationLoss extends Loss {
/**
* Computes the class probability given the margin.
*/
private[spark] def computeProbability(margin: Double): Double
}
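
The new `ClassificationLoss` trait keeps the margin-to-probability mapping pluggable per loss. As a purely hypothetical illustration (exponential loss is not part of this patch; the object name and formulas below are invented for the example), another classification loss would slot in like this:

```scala
// Hypothetical AdaBoost-style exponential loss, shown only to illustrate the
// ClassificationLoss contract; NOT part of this change.
private[spark] object ExponentialLossExample extends ClassificationLoss {
  override def gradient(prediction: Double, label: Double): Double =
    -label * math.exp(-label * prediction)

  override private[spark] def computeError(prediction: Double, label: Double): Double =
    math.exp(-label * prediction)

  // Estimated P(label = 1) implied by an exponential-loss margin.
  override private[spark] def computeProbability(margin: Double): Double =
    1.0 / (1.0 + math.exp(-2.0 * margin))
}
```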
@@ -17,20 +17,24 @@
package org.apache.spark.ml.classification
import com.github.fommil.netlib.BLAS
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.loss.LogLoss
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.util.Utils
/**
@@ -49,6 +53,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
private var data: RDD[LabeledPoint] = _
private var trainData: RDD[LabeledPoint] = _
private var validationData: RDD[LabeledPoint] = _
private val eps: Double = 1e-5
private val absEps: Double = 1e-8
override def beforeAll() {
super.beforeAll()
@@ -66,10 +72,156 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
ParamsSuite.checkParams(new GBTClassifier)
val model = new GBTClassificationModel("gbtc",
Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)),
Array(1.0), 1)
Array(1.0), 1, 2)
ParamsSuite.checkParams(model)
}
test("GBTClassifier: default params") {
val gbt = new GBTClassifier
assert(gbt.getLabelCol === "label")
assert(gbt.getFeaturesCol === "features")
assert(gbt.getPredictionCol === "prediction")
assert(gbt.getRawPredictionCol === "rawPrediction")
assert(gbt.getProbabilityCol === "probability")
val df = trainData.toDF()
val model = gbt.fit(df)
model.transform(df)
.select("label", "probability", "prediction", "rawPrediction")
.collect()
intercept[NoSuchElementException] {
model.getThresholds
}
assert(model.getFeaturesCol === "features")
assert(model.getPredictionCol === "prediction")
assert(model.getRawPredictionCol === "rawPrediction")
assert(model.getProbabilityCol === "probability")
assert(model.hasParent)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
}
test("setThreshold, getThreshold") {
val gbt = new GBTClassifier
// default
withClue("GBTClassifier should not have thresholds set by default.") {
intercept[NoSuchElementException] {
gbt.getThresholds
}
}
// Set via thresholds
val gbt2 = new GBTClassifier
val threshold = Array(0.3, 0.7)
gbt2.setThresholds(threshold)
assert(gbt2.getThresholds === threshold)
}
test("thresholds prediction") {
val gbt = new GBTClassifier
val df = trainData.toDF()
val binaryModel = gbt.fit(df)
// should predict all zeros
binaryModel.setThresholds(Array(0.0, 1.0))
val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect()
assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0))
// should predict all ones
binaryModel.setThresholds(Array(1.0, 0.0))
val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect()
assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0))
val gbtBase = new GBTClassifier
val model = gbtBase.fit(df)
val basePredictions = model.transform(df).select("prediction").collect()
// constant threshold scaling is the same as no thresholds
binaryModel.setThresholds(Array(1.0, 1.0))
val scaledPredictions = binaryModel.transform(df).select("prediction").collect()
assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
scaled.getDouble(0) === base.getDouble(0)
})
// force it to use the predict method
model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1))
val predictionsWithPredict = model.transform(df).select("prediction").collect()
assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0))
}
test("GBTClassifier: Predictor, Classifier methods") {
val rawPredictionCol = "rawPrediction"
val predictionCol = "prediction"
val labelCol = "label"
val featuresCol = "features"
val probabilityCol = "probability"
val gbt = new GBTClassifier().setSeed(123)
val trainingDataset = trainData.toDF(labelCol, featuresCol)
val gbtModel = gbt.fit(trainingDataset)
assert(gbtModel.numClasses === 2)
val numFeatures = trainingDataset.select(featuresCol).first().getAs[Vector](0).size
assert(gbtModel.numFeatures === numFeatures)
val blas = BLAS.getInstance()
val validationDataset = validationData.toDF(labelCol, featuresCol)
val results = gbtModel.transform(validationDataset)
// check that raw prediction is tree predictions dot tree weights
results.select(rawPredictionCol, featuresCol).collect().foreach {
case Row(raw: Vector, features: Vector) =>
assert(raw.size === 2)
val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1)
assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
}
// Compare rawPrediction with probability
results.select(rawPredictionCol, probabilityCol).collect().foreach {
case Row(raw: Vector, prob: Vector) =>
assert(raw.size === 2)
assert(prob.size === 2)
// Note: we should check other loss types for classification if they are added
val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value))
assert(prob(0) ~== predFromRaw(0) relTol eps)
assert(prob(1) ~== predFromRaw(1) relTol eps)
assert(prob(0) + prob(1) ~== 1.0 absTol absEps)
}
// Compare prediction with probability
results.select(predictionCol, probabilityCol).collect().foreach {
case Row(pred: Double, prob: Vector) =>
val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
assert(pred == predFromProb)
}
// force it to use raw2prediction
gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("")
val resultsUsingRaw2Predict =
gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
case (pred1, pred2) => assert(pred1 === pred2)
}
// force it to use probability2prediction
gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol)
val resultsUsingProb2Predict =
gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
case (pred1, pred2) => assert(pred1 === pred2)
}
// force it to use predict
gbtModel.setRawPredictionCol("").setProbabilityCol("")
val resultsUsingPredict =
gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
case (pred1, pred2) => assert(pred1 === pred2)
}
}
test("GBT parameter stepSize should be in interval (0, 1]") {
withClue("GBT parameter stepSize should be in interval (0, 1]") {
intercept[IllegalArgumentException] {
@@ -246,7 +398,8 @@ private object GBTClassifierSuite extends SparkFunSuite {
val newModel = gbt.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = GBTClassificationModel.fromOld(
oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures, numFeatures)
oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures,
numFeatures, numClasses = 2)
TreeTests.checkEqual(oldModelAsNew, newModel)
assert(newModel.numFeatures === numFeatures)
assert(oldModelAsNew.numFeatures === numFeatures)
......