diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 837d0591478c5be4c371f2bafd90eaba57a7d6f8..0890e6263e1655d64e4c9c60192a60a21fde030c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -189,9 +189,10 @@ object DecisionTreeRunner {
     // Create training, test sets.
     val splits = if (params.testInput != "") {
       // Load testInput.
+      val numFeatures = examples.take(1)(0).features.size
       val origTestExamples = params.dataFormat match {
         case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
-        case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput)
+        case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures)
       }
       params.algo match {
         case Classification => {
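The explicit numFeatures argument matters because MLUtils.loadLibSVMFile otherwise infers the feature dimension from the largest feature index present in the file it loads, so a sparse test file that happens to omit the highest-indexed features would produce vectors shorter than the training vectors. A minimal standalone sketch of the pattern, assuming hypothetical train/test paths and a helper name of my own invention:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils

// Sketch only: align the test set's dimension with the training set's.
def loadAlignedSets(sc: SparkContext, trainPath: String, testPath: String) = {
  val training = MLUtils.loadLibSVMFile(sc, trainPath)
  // Take the dimension from the training data, not from the test file.
  val numFeatures = training.take(1)(0).features.size
  val test = MLUtils.loadLibSVMFile(sc, testPath, numFeatures)
  (training, test)
}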
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 55f422dff0d71710662c49ff465c305462c90c1d..ce8825cc03229c05567e69277a383085341c5798 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -64,12 +64,6 @@ private[tree] class DTStatsAggregator(
     numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
   }
 
-  /**
-   * Indicator for each feature of whether that feature is an unordered feature.
-   * TODO: Is Array[Boolean] any faster?
-   */
-  def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
-
   /**
    * Total number of elements stored in this aggregator
    */
@@ -128,21 +122,13 @@ private[tree] class DTStatsAggregator(
    * Pre-compute feature offset for use with [[featureUpdate]].
    * For ordered features only.
    */
-  def getFeatureOffset(featureIndex: Int): Int = {
-    require(!isUnordered(featureIndex),
-      s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" +
-        s" for unordered feature $featureIndex.")
-    featureOffsets(featureIndex)
-  }
+  def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
 
   /**
    * Pre-compute feature offset for use with [[featureUpdate]].
    * For unordered features only.
    */
   def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
-    require(isUnordered(featureIndex),
-      s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," +
-        s" but was called for ordered feature $featureIndex.")
     val baseOffset = featureOffsets(featureIndex)
     (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 212dce25236e074bd7e53947e295ec0943dc3af6..772c02670e5416885f0d91305bf031757ab2dbf0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.impl
 
 import scala.collection.mutable
 
+import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
@@ -82,7 +83,7 @@ private[tree] class DecisionTreeMetadata(
 }
 
-private[tree] object DecisionTreeMetadata {
+private[tree] object DecisionTreeMetadata extends Logging {
 
   /**
    * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
@@ -103,6 +104,10 @@ private[tree] object DecisionTreeMetadata {
     }
 
     val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+    if (maxPossibleBins < strategy.maxBins) {
+      logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
+        s" (= number of training instances)")
+    }
 
     // We check the number of bins here against maxPossibleBins.
     // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index d8476b5cd7bc757c3a94f7c690317cb749c6a1a3..004838ee5ba0e6d2fdf40dacf939e5b0f630981b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.mllib.tree.model
 
+import org.apache.spark.annotation.DeveloperApi
+
 /**
  * Predicted value for a node
  * @param predict predicted value
  * @param prob probability of the label (classification only)
  */
-private[tree] class Predict(
+@DeveloperApi
+class Predict(
     val predict: Double,
     val prob: Double = 0.0) extends Serializable {
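With the isUnordered checks removed, both offset getters reduce to plain arithmetic over the flat stats array laid out by featureOffsets. A self-contained illustration of that layout, with invented bin counts and stats size (none of these values come from the patch):

object OffsetLayoutDemo {
  def main(args: Array[String]): Unit = {
    val statsSize = 3             // hypothetical: stats entries per bin
    val numBins = Array(4, 8, 2)  // hypothetical: bins per feature
    // Same construction as featureOffsets above: cumulative start index
    // of each feature's block in the flat stats array.
    val featureOffsets = numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
    println(featureOffsets.mkString(", "))  // 0, 12, 36, 42

    // For an unordered feature, the first half of its bins holds "left child"
    // stats and the second half "right child" stats, hence the shift by
    // (numBins / 2) * statsSize.
    val featureIndex = 1
    val baseOffset = featureOffsets(featureIndex)
    println((baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize))  // (12,24)
  }
}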
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
index 4d66d6d81caa5cc5f1a5fd0d4d36ffcb39f98883..6a22e2abe59bdfd61ac00db5e4f13bc347fada1b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
@@ -82,9 +82,9 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext
    */
   override def toString: String = algo match {
     case Classification =>
-      s"RandomForestModel classifier with $numTrees trees"
+      s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes"
     case Regression =>
-      s"RandomForestModel regressor with $numTrees trees"
+      s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes"
     case _ => throw new IllegalArgumentException(
       s"RandomForestModel given unknown algo parameter: $algo.")
   }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 20d372dc1d3ca40346162c6a10c83100950e95c5..fb44ceb0f57ee6f1ef17ad3142cbbd130fe8bab6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -173,6 +173,22 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
     checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
   }
 
+  test("alternating categorical and continuous features with multiclass labels to test indexing") {
+    val arr = new Array[LabeledPoint](4)
+    arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))
+    arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0))
+    arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0))
+    arr(3) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0))
+    val categoricalFeaturesInfo = Map(0 -> 3, 2 -> 2, 4 -> 4)
+    val input = sc.parallelize(arr)
+
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
+    val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
+      featureSubsetStrategy = "sqrt", seed = 12345)
+    RandomForestSuite.validateClassifier(model, arr, 1.0)
+  }
+
 }
 
 object RandomForestSuite {
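The new test doubles as a usage template: RandomForest.trainClassifier takes a Strategy describing mixed categorical and continuous features. A standalone driver along the same lines, reusing the test's data points (only the app name and master setting are invented; with four instances and the default maxBins = 32, it also exercises the new logWarning added in DecisionTreeMetadata):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

object RandomForestMixedFeaturesExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("rf-example").setMaster("local"))
    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)),
      LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)),
      LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0))))
    // Features 0, 2, and 4 are categorical with 3, 2, and 4 categories;
    // features 1 and 3 are continuous.
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 2 -> 2, 4 -> 4))
    val model = RandomForest.trainClassifier(data, strategy, numTrees = 2,
      featureSubsetStrategy = "sqrt", seed = 12345)
    // toString now reports the total node count as well as the tree count.
    println(model)
    sc.stop()
  }
}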