From 69f5a7c934ac553ed52c00679b800bcffe83c1d6 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph@databricks.com>
Date: Mon, 3 Aug 2015 10:46:34 -0700
Subject: [PATCH] [SPARK-9528] [ML] Changed RandomForestClassifier to extend
 ProbabilisticClassifier

RandomForestClassifier now outputs rawPrediction based on tree probabilities, plus probability column computed from normalized rawPrediction.

CC: holdenk

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7859 from jkbradley/rf-prob and squashes the following commits:

6c28f51 [Joseph K. Bradley] Changed RandomForestClassifier to extend ProbabilisticClassifier
---
 .../DecisionTreeClassifier.scala      |  8 +---
 .../ProbabilisticClassifier.scala     | 27 +++++++++++++-
 .../RandomForestClassifier.scala      | 37 +++++++++++++------
 .../RandomForestClassifierSuite.scala | 36 ++++++++++++++----
 4 files changed, 81 insertions(+), 27 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index f27cfd0331..f2b992f8ba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -131,13 +131,7 @@ final class DecisionTreeClassificationModel private[ml] (
   override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
     rawPrediction match {
       case dv: DenseVector =>
-        var i = 0
-        val size = dv.size
-        val sum = dv.values.sum
-        while (i < size) {
-          dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0
-          i += 1
-        }
+        ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
         dv
       case sv: SparseVector =>
         throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index dad4511086..f9c9c2371f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, DataType, StructType}
@@ -175,3 +175,28 @@ private[spark] abstract class ProbabilisticClassificationModel[
    */
   protected def probability2prediction(probability: Vector): Double = probability.argmax
 }
+
+private[ml] object ProbabilisticClassificationModel {
+
+  /**
+   * Normalize a vector of raw predictions to be a multinomial probability vector, in place.
+   *
+   * The input raw predictions should be >= 0.
+   * The output vector sums to 1, unless the input vector is all-0 (in which case the output is
+   * all-0 too).
+   *
+   * NOTE: This is NOT applicable to all models, only ones which effectively use class
+   * instance counts for raw predictions.
+   */
+  def normalizeToProbabilitiesInPlace(v: DenseVector): Unit = {
+    val sum = v.values.sum
+    if (sum != 0) {
+      var i = 0
+      val size = v.size
+      while (i < size) {
+        v.values(i) /= sum
+        i += 1
+      }
+    }
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 0c7eb4a662..56e80cc8fe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -17,22 +17,19 @@
 
 package org.apache.spark.ml.classification
 
-import scala.collection.mutable
-
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
+
 
 /**
  * :: Experimental ::
@@ -43,7 +40,7 @@ import org.apache.spark.sql.types.DoubleType
  */
 @Experimental
 final class RandomForestClassifier(override val uid: String)
-  extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
+  extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
   with RandomForestParams with TreeClassifierParams {
 
   def this() = this(Identifiable.randomUID("rfc"))
@@ -127,7 +124,7 @@ final class RandomForestClassificationModel private[ml] (
     override val uid: String,
     private val _trees: Array[DecisionTreeClassificationModel],
     override val numClasses: Int)
-  extends ClassificationModel[Vector, RandomForestClassificationModel]
+  extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
   with TreeEnsembleModel with Serializable {
 
   require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
@@ -157,15 +154,33 @@ final class RandomForestClassificationModel private[ml] (
   override protected def predictRaw(features: Vector): Vector = {
     // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
     // Classifies using majority votes.
-    // Ignore the weights since all are 1.0 for now.
-    val votes = new Array[Double](numClasses)
+    // Ignore the tree weights since all are 1.0 for now.
+    val votes = Array.fill[Double](numClasses)(0.0)
     _trees.view.foreach { tree =>
-      val prediction = tree.rootNode.predictImpl(features).prediction.toInt
-      votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
+      val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
+      val total = classCounts.sum
+      if (total != 0) {
+        var i = 0
+        while (i < numClasses) {
+          votes(i) += classCounts(i) / total
+          i += 1
+        }
+      }
     }
     Vectors.dense(votes)
   }
 
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+    rawPrediction match {
+      case dv: DenseVector =>
+        ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
+        dv
+      case sv: SparseVector =>
+        throw new RuntimeException("Unexpected error in RandomForestClassificationModel:" +
+          " raw2probabilityInPlace encountered SparseVector")
+    }
+  }
+
   override def copy(extra: ParamMap): RandomForestClassificationModel = {
     copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index dbb2577c62..edf848b21a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 
@@ -121,6 +122,33 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
     compareAPIs(rdd, rf2, categoricalFeatures, numClasses)
   }
 
+  test("predictRaw and predictProbability") {
+    val rdd = orderedLabeledPoints5_20
+    val rf = new RandomForestClassifier()
+      .setImpurity("Gini")
+      .setMaxDepth(3)
+      .setNumTrees(3)
+      .setSeed(123)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val numClasses = 2
+
+    val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+    val model = rf.fit(df)
+
+    val predictions = model.transform(df)
+      .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol)
+      .collect()
+
+    predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+      assert(pred === rawPred.argmax,
+        s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
+      val sum = rawPred.toArray.sum
+      assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
+        "probability prediction mismatch")
+      assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
+    }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
@@ -173,13 +201,5 @@ private object RandomForestClassifierSuite {
     assert(newModel.hasParent)
     assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
     assert(newModel.numClasses == numClasses)
-    val results = newModel.transform(newData)
-    results.select("rawPrediction", "prediction").collect().foreach {
-      case Row(raw: Vector, prediction: Double) => {
-        assert(raw.size == numClasses)
-        val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2
-        assert(predFromRaw == prediction)
-      }
-    }
   }
 }
-- 
GitLab
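
What the new predictRaw computes: before this patch each tree cast one hard vote for its predicted class (votes(prediction) += 1.0); now each tree contributes its leaf's class-count vector (impurityStats.stats), normalized to sum to 1, so every tree still casts exactly one unit of vote but spreads it across classes. rawPrediction is the element-wise sum of those per-tree vectors (so it sums to numTrees whenever no leaf has all-zero counts), and raw2probabilityInPlace rescales it into the probability column. Below is a minimal standalone sketch of that arithmetic, assuming the per-tree leaf class counts are already in hand; RfProbSketch, rawPrediction, and toProbability are names invented here for illustration, not part of the patch.

    object RfProbSketch {

      // Soft voting: each tree contributes its leaf's class-count vector,
      // normalized so that every tree casts exactly one unit of vote.
      def rawPrediction(leafClassCounts: Seq[Array[Double]], numClasses: Int): Array[Double] = {
        val votes = new Array[Double](numClasses)
        leafClassCounts.foreach { counts =>
          val total = counts.sum
          if (total != 0) {
            var i = 0
            while (i < numClasses) {
              votes(i) += counts(i) / total
              i += 1
            }
          }
        }
        votes
      }

      // Mirrors normalizeToProbabilitiesInPlace: divide by the sum, or leave
      // an all-zero vector untouched.
      def toProbability(raw: Array[Double]): Array[Double] = {
        val sum = raw.sum
        if (sum == 0.0) raw.clone() else raw.map(_ / sum)
      }

      def main(args: Array[String]): Unit = {
        // Two trees over three classes, with leaf counts (3,1,0) and (1,2,1).
        val leaves = Seq(Array(3.0, 1.0, 0.0), Array(1.0, 2.0, 1.0))
        val raw = rawPrediction(leaves, numClasses = 3)
        println(raw.mkString(", "))                // 1.0, 0.75, 0.25 (sums to numTrees = 2)
        println(toProbability(raw).mkString(", ")) // 0.5, 0.375, 0.125
      }
    }

Summing normalized per-tree distributions, rather than raw counts, keeps trees with large and small leaves equally weighted, which is why predictRaw divides by each leaf's total before accumulating.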
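For the user-facing effect, here is a minimal usage sketch against the Spark 1.5-era API, assuming an existing SparkContext sc and SQLContext sqlContext; the toy dataset and column names are invented for illustration. Tree classifiers in spark.ml of that era require the label column to carry numClasses metadata, so the sketch runs the label through StringIndexer first, as the ML docs of the time did.

    import org.apache.spark.ml.classification.RandomForestClassifier
    import org.apache.spark.ml.feature.StringIndexer
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // Toy two-class dataset; createDataFrame yields "label" and "features" columns.
    val training = sqlContext.createDataFrame(sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(0.0, Vectors.dense(0.1, 0.9)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.1)),
      LabeledPoint(1.0, Vectors.dense(0.9, 0.0)))))

    // Index the label so it carries the numClasses metadata the trees need.
    val indexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(training)
    val indexed = indexer.transform(training)

    val rf = new RandomForestClassifier()
      .setLabelCol("indexedLabel")
      .setNumTrees(3)
      .setMaxDepth(2)
      .setSeed(42)
    val model = rf.fit(indexed)

    // With this patch, probability appears alongside rawPrediction and
    // prediction, and prediction == argmax(rawPrediction) == argmax(probability).
    model.transform(indexed)
      .select("prediction", "rawPrediction", "probability")
      .show()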