From 69f5a7c934ac553ed52c00679b800bcffe83c1d6 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph@databricks.com>
Date: Mon, 3 Aug 2015 10:46:34 -0700
Subject: [PATCH] [SPARK-9528] [ML] Changed RandomForestClassifier to extend
 ProbabilisticClassifier

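RandomForestClassifier now extends ProbabilisticClassifier: rawPrediction is the sum over trees of each tree's normalized class counts (per-class probabilities) at the predicted node, and the new probability column is rawPrediction normalized to sum to 1.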

CC: holdenk

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7859 from jkbradley/rf-prob and squashes the following commits:

6c28f51 [Joseph K. Bradley] Changed RandomForestClassifier to extend ProbabilisticClassifier
---
 .../DecisionTreeClassifier.scala              |  8 +---
 .../ProbabilisticClassifier.scala             | 27 +++++++++++++-
 .../RandomForestClassifier.scala              | 37 +++++++++++++------
 .../RandomForestClassifierSuite.scala         | 36 ++++++++++++++----
 4 files changed, 81 insertions(+), 27 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index f27cfd0331..f2b992f8ba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -131,13 +131,7 @@ final class DecisionTreeClassificationModel private[ml] (
   override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
     rawPrediction match {
       case dv: DenseVector =>
-        var i = 0
-        val size = dv.size
-        val sum = dv.values.sum
-        while (i < size) {
-          dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0
-          i += 1
-        }
+        ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
         dv
       case sv: SparseVector =>
         throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index dad4511086..f9c9c2371f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, DataType, StructType}
@@ -175,3 +175,28 @@ private[spark] abstract class ProbabilisticClassificationModel[
    */
   protected def probability2prediction(probability: Vector): Double = probability.argmax
 }
+
+private[ml] object ProbabilisticClassificationModel {
+
+  /**
+   * Normalize a vector of raw predictions to be a multinomial probability vector, in place.
+   *
+   * The input raw predictions should be >= 0.
+   * The output vector sums to 1, unless the input values sum to 0 (e.g. an all-0 input), in
+   * which case the vector is left unmodified.
+   *
+   * NOTE: This is NOT applicable to all models, only ones which effectively use class
+   *       instance counts for raw predictions.
+   */
+  def normalizeToProbabilitiesInPlace(v: DenseVector): Unit = {
+    val sum = v.values.sum
+    if (sum != 0) {
+      var i = 0
+      val size = v.size
+      while (i < size) {
+        v.values(i) /= sum
+        i += 1
+      }
+    }
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 0c7eb4a662..56e80cc8fe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -17,22 +17,19 @@
 
 package org.apache.spark.ml.classification
 
-import scala.collection.mutable
-
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
+
 
 /**
  * :: Experimental ::
@@ -43,7 +40,7 @@ import org.apache.spark.sql.types.DoubleType
  */
 @Experimental
 final class RandomForestClassifier(override val uid: String)
-  extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
+  extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
   with RandomForestParams with TreeClassifierParams {
 
   def this() = this(Identifiable.randomUID("rfc"))
@@ -127,7 +124,7 @@ final class RandomForestClassificationModel private[ml] (
     override val uid: String,
     private val _trees: Array[DecisionTreeClassificationModel],
     override val numClasses: Int)
-  extends ClassificationModel[Vector, RandomForestClassificationModel]
+  extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
   with TreeEnsembleModel with Serializable {
 
   require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
@@ -157,15 +154,33 @@ final class RandomForestClassificationModel private[ml] (
   override protected def predictRaw(features: Vector): Vector = {
     // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
     // Classifies using majority votes.
-    // Ignore the weights since all are 1.0 for now.
-    val votes = new Array[Double](numClasses)
+    // Soft votes: each tree adds normalized class counts. Ignore tree weights (all 1.0 for now).
+    val votes = Array.fill[Double](numClasses)(0.0)
     _trees.view.foreach { tree =>
-      val prediction = tree.rootNode.predictImpl(features).prediction.toInt
-      votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
+      val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
+      val total = classCounts.sum
+      if (total != 0) {
+        var i = 0
+        while (i < numClasses) {
+          votes(i) += classCounts(i) / total
+          i += 1
+        }
+      }
     }
     Vectors.dense(votes)
   }
 
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+    rawPrediction match {
+      case dv: DenseVector =>
+        ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
+        dv
+      case sv: SparseVector =>
+        throw new RuntimeException("Unexpected error in RandomForestClassificationModel:" +
+          " raw2probabilityInPlace encountered SparseVector")
+    }
+  }
+
   override def copy(extra: ParamMap): RandomForestClassificationModel = {
     copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index dbb2577c62..edf848b21a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 
@@ -121,6 +122,33 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
     compareAPIs(rdd, rf2, categoricalFeatures, numClasses)
   }
 
+  test("predictRaw and predictProbability") {
+    val rdd = orderedLabeledPoints5_20
+    val rf = new RandomForestClassifier()
+      .setImpurity("Gini")
+      .setMaxDepth(3)
+      .setNumTrees(3)
+      .setSeed(123)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val numClasses = 2
+
+    val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+    val model = rf.fit(df)
+
+    val predictions = model.transform(df)
+      .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol)
+      .collect()
+
+    predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+      assert(pred === rawPred.argmax,
+        s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
+      val sum = rawPred.toArray.sum
+      assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
+        "probability prediction mismatch")
+      assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
+    }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
@@ -173,13 +201,5 @@ private object RandomForestClassifierSuite {
     assert(newModel.hasParent)
     assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
     assert(newModel.numClasses == numClasses)
-    val results = newModel.transform(newData)
-    results.select("rawPrediction", "prediction").collect().foreach {
-      case Row(raw: Vector, prediction: Double) => {
-        assert(raw.size == numClasses)
-        val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2
-        assert(predFromRaw == prediction)
-      }
-    }
   }
 }
-- 
GitLab