Commit 37c2d192 authored by Holden Karau, committed by Joseph K. Bradley

[SPARK-9016] [ML] make random forest classifiers implement classification trait

Implement the classification trait for RandomForestClassifiers. The plan is to use this in the future to provide thresholding for RandomForestClassifiers (as well as other classifiers that implement that trait).
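As a hedged sketch (not part of the patch itself) of what implementing the Classifier trait means here, mirroring the code added in the diff below: each tree casts an equally weighted vote for a class, the per-class vote counts become the raw prediction vector, and the final label is the index of the largest count, which is exactly what the updated test suite checks. The helper names below (rawVotes, predictionFromRaw, treePredictions) are illustrative, not Spark API.

import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Collect equally weighted votes from per-tree class predictions.
// `treePredictions` stands in for the forest's individual tree outputs.
def rawVotes(treePredictions: Seq[Int], numClasses: Int): Vector = {
  val votes = new Array[Double](numClasses)
  treePredictions.foreach(cls => votes(cls) += 1.0) // each tree has weight 1.0
  Vectors.dense(votes)
}

// Recover the predicted label from the raw votes (argmax), as the new test asserts.
def predictionFromRaw(raw: Vector): Double =
  raw.toArray.zipWithIndex.maxBy(_._1)._2.toDouble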

Author: Holden Karau <holden@pigscanfly.ca>

Closes #7432 from holdenk/SPARK-9016-make-random-forest-classifiers-implement-classification-trait and squashes the following commits:

bf22fa6 [Holden Karau] Add missing imports for testing suite
e948f0d [Holden Karau] Check the prediction generation from rawPrediction
25320c3 [Holden Karau] Don't supply numClasses when not needed, assert model classes are as expected
1a67e04 [Holden Karau] Use old decision tree stuff instead
673e0c3 [Holden Karau] Merge branch 'master' into SPARK-9016-make-random-forest-classifiers-implement-classification-trait
0d15b96 [Holden Karau] Fix typo
5eafad4 [Holden Karau] add a constructor for rootnode + num classes
fc6156f [Holden Karau] scala style fix
2597915 [Holden Karau] take num classes in constructor
3ccfe4a [Holden Karau] Merge in master, make pass numClasses through randomforest for training
222a10b [Holden Karau] Increase numTrees to 3 in the python test since before the two were equal and the argmax was selecting the last one
16aea1c [Holden Karau] Make tests match the new models
b454a02 [Holden Karau] Make the Tree classifiers extends the Classifier base class
77b4114 [Holden Karau] Import vectors lib
parent 103d8cce
@@ -25,7 +25,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
@@ -43,7 +43,7 @@ import org.apache.spark.sql.types.DoubleType
*/
@Experimental
final class RandomForestClassifier(override val uid: String)
-extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
+extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestParams with TreeClassifierParams {
def this() = this(Identifiable.randomUID("rfc"))
@@ -98,7 +98,7 @@ final class RandomForestClassifier(override val uid: String)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeClassificationModel])
-new RandomForestClassificationModel(trees)
+new RandomForestClassificationModel(trees, numClasses)
}
override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
@@ -125,8 +125,9 @@ object RandomForestClassifier {
@Experimental
final class RandomForestClassificationModel private[ml] (
override val uid: String,
-private val _trees: Array[DecisionTreeClassificationModel])
-extends PredictionModel[Vector, RandomForestClassificationModel]
+private val _trees: Array[DecisionTreeClassificationModel],
+override val numClasses: Int)
+extends ClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
@@ -135,8 +136,8 @@ final class RandomForestClassificationModel private[ml] (
* Construct a random forest classification model, with all trees weighted equally.
* @param trees Component trees
*/
-def this(trees: Array[DecisionTreeClassificationModel]) =
-this(Identifiable.randomUID("rfc"), trees)
+def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) =
+this(Identifiable.randomUID("rfc"), trees, numClasses)
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
@@ -153,20 +154,20 @@ final class RandomForestClassificationModel private[ml] (
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
-override protected def predict(features: Vector): Double = {
+override protected def predictRaw(features: Vector): Vector = {
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
// Classifies using majority votes.
// Ignore the weights since all are 1.0 for now.
-val votes = mutable.Map.empty[Int, Double]
+val votes = new Array[Double](numClasses)
_trees.view.foreach { tree =>
val prediction = tree.rootNode.predict(features).toInt
-votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight
+votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
}
-votes.maxBy(_._2)._1
+Vectors.dense(votes)
}
override def copy(extra: ParamMap): RandomForestClassificationModel = {
-copyValues(new RandomForestClassificationModel(uid, _trees), extra)
+copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
}
override def toString: String = {
@@ -185,7 +186,8 @@ private[ml] object RandomForestClassificationModel {
def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
-categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = {
+categoricalFeatures: Map[Int, Int],
+numClasses: Int): RandomForestClassificationModel = {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
@@ -193,6 +195,6 @@ private[ml] object RandomForestClassificationModel {
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
-new RandomForestClassificationModel(uid, newTrees)
+new RandomForestClassificationModel(uid, newTrees, numClasses)
}
}
@@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
/**
* Test suite for [[RandomForestClassifier]].
@@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
test("params") {
ParamsSuite.checkParams(new RandomForestClassifier)
val model = new RandomForestClassificationModel("rfc",
-Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))))
+Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2)
ParamsSuite.checkParams(model)
}
@@ -167,9 +167,19 @@ private object RandomForestClassifierSuite {
val newModel = rf.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = RandomForestClassificationModel.fromOld(
-oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
+oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures,
+numClasses)
TreeTests.checkEqual(oldModelAsNew, newModel)
assert(newModel.hasParent)
assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
+assert(newModel.numClasses == numClasses)
+val results = newModel.transform(newData)
+results.select("rawPrediction", "prediction").collect().foreach {
+case Row(raw: Vector, prediction: Double) => {
+assert(raw.size == numClasses)
+val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2
+assert(predFromRaw == prediction)
+}
+}
}
}
@@ -299,9 +299,9 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
->>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
+>>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
>>> model = rf.fit(td)
->>> allclose(model.treeWeights, [1.0, 1.0])
+>>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
......
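For illustration only, a usage sketch of what this change means downstream: because the model now extends ClassificationModel, transform() adds a rawPrediction column (the per-class vote counts) alongside prediction, as the updated test suite above verifies. The DataFrames trainingDF and testDF are assumed placeholders with "label" and "features" columns, not part of the patch.

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.sql.DataFrame

def sketch(trainingDF: DataFrame, testDF: DataFrame): Unit = {
  val rf = new RandomForestClassifier()
    .setNumTrees(3)
    .setSeed(42L)
  val model = rf.fit(trainingDF)
  // rawPrediction holds one vote count per class; prediction is its argmax.
  model.transform(testDF).select("rawPrediction", "prediction").show()
}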