From 7bf6cc9701cbb0f77fb85a412e387fb92274fca5 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" <joseph.kurata.bradley@gmail.com> Date: Wed, 1 Oct 2014 01:03:24 -0700 Subject: [PATCH] [SPARK-3751] [mllib] DecisionTree: example update + print options DecisionTreeRunner functionality additions: * Allow user to pass in a test dataset * Do not print full model if the model is too large. As part of this, modify DecisionTreeModel and RandomForestModel to allow printing less info. Proposed updates: * toString: prints model summary * toDebugString: prints full model (named after RDD.toDebugString) Similar update to Python API: * __repr__() now prints a model summary * toDebugString() now prints the full model CC: mengxr chouqin manishamde codedeft Small update (whoever can take a look). Thanks! Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Closes #2604 from jkbradley/dtrunner-update and squashes the following commits: b2b3c60 [Joseph K. Bradley] re-added python sql doc test, temporarily removed before 07b1fae [Joseph K. Bradley] repr() now prints a model summary toDebugString() now prints the full model 1d0d93d [Joseph K. Bradley] Updated DT and RF to print less when toString is called. Added toDebugString for verbose printing. 22eac8c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update e007a95 [Joseph K. Bradley] Updated DecisionTreeRunner to accept a test dataset. 
--- .../examples/mllib/DecisionTreeRunner.scala | 99 ++++++++++++++----- .../mllib/tree/model/DecisionTreeModel.scala | 14 ++- .../mllib/tree/model/RandomForestModel.scala | 30 ++++-- python/pyspark/mllib/tree.py | 10 +- 4 files changed, 111 insertions(+), 42 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 96fb068e9e..4adc91d2fb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -52,6 +52,7 @@ object DecisionTreeRunner { case class Params( input: String = null, + testInput: String = "", dataFormat: String = "libsvm", algo: Algo = Classification, maxDepth: Int = 5, @@ -98,13 +99,18 @@ object DecisionTreeRunner { s"default: ${defaultParams.featureSubsetStrategy}") .action((x, c) => c.copy(featureSubsetStrategy = x)) opt[Double]("fracTest") - .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) opt[String]("<dataFormat>") .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") .action((x, c) => c.copy(dataFormat = x)) arg[String]("<input>") - .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)") + .text("input path to labeled examples") .required() .action((x, c) => c.copy(input = x)) checkConfig { params => @@ -141,7 +147,7 @@ object DecisionTreeRunner { case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache() } // For classification, re-index classes if needed. - val (examples, numClasses) = params.algo match { + val (examples, classIndexMap, numClasses) = params.algo match { case Classification => { // classCounts: class --> # examples in class val classCounts = origExamples.map(_.label).countByValue() @@ -170,16 +176,40 @@ object DecisionTreeRunner { val frac = classCounts(c) / numExamples.toDouble println(s"$c\t$frac\t${classCounts(c)}") } - (examples, numClasses) + (examples, classIndexMap, numClasses) } case Regression => - (origExamples, 0) + (origExamples, null, 0) case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") } - // Split into training, test. - val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) + // Create training, test sets. + val splits = if (params.testInput != "") { + // Load testInput. 
+ val origTestExamples = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput) + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput) + } + params.algo match { + case Classification => { + // classCounts: class --> # examples in class + val testExamples = { + if (classIndexMap.isEmpty) { + origTestExamples + } else { + origTestExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features)) + } + } + Array(examples, testExamples) + } + case Regression => + Array(examples, origTestExamples) + } + } else { + // Split input into training, test. + examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) + } val training = splits(0).cache() val test = splits(1).cache() val numTraining = training.count() @@ -206,47 +236,62 @@ object DecisionTreeRunner { minInfoGain = params.minInfoGain) if (params.numTrees == 1) { val model = DecisionTree.train(training, strategy) - println(model) + if (model.numNodes < 20) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. 
+ } if (params.algo == Classification) { - val accuracy = + val trainAccuracy = + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) + .precision + println(s"Train accuracy = $trainAccuracy") + val testAccuracy = new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision - println(s"Test accuracy = $accuracy") + println(s"Test accuracy = $testAccuracy") } if (params.algo == Regression) { - val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse") + val trainMSE = meanSquaredError(model, training) + println(s"Train mean squared error = $trainMSE") + val testMSE = meanSquaredError(model, test) + println(s"Test mean squared error = $testMSE") } } else { val randomSeed = Utils.random.nextInt() if (params.algo == Classification) { val model = RandomForest.trainClassifier(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) - println(model) - val accuracy = + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } + val trainAccuracy = + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) + .precision + println(s"Train accuracy = $trainAccuracy") + val testAccuracy = new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision - println(s"Test accuracy = $accuracy") + println(s"Test accuracy = $testAccuracy") } if (params.algo == Regression) { val model = RandomForest.trainRegressor(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) - println(model) - val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. 
+ } + val trainMSE = meanSquaredError(model, training) + println(s"Train mean squared error = $trainMSE") + val testMSE = meanSquaredError(model, test) + println(s"Test mean squared error = $testMSE") } } sc.stop() } - /** - * Calculates the classifier accuracy. - */ - private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { - val correctCount = data.filter(y => model.predict(y.features) == y.label).count() - val count = data.count() - correctCount.toDouble / count - } - /** * Calculates the mean squared error for regression. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 271b2c4ad8..ec1d99ab26 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -68,15 +68,23 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable } /** - * Print full model. + * Print a summary of the model. */ override def toString: String = algo match { case Classification => - s"DecisionTreeModel classifier\n" + topNode.subtreeToString(2) + s"DecisionTreeModel classifier of depth $depth with $numNodes nodes" case Regression => - s"DecisionTreeModel regressor\n" + topNode.subtreeToString(2) + s"DecisionTreeModel regressor of depth $depth with $numNodes nodes" case _ => throw new IllegalArgumentException( s"DecisionTreeModel given unknown algo parameter: $algo.") } + /** + * Print the full model to a string. 
+ */ + def toDebugString: String = { + val header = toString + "\n" + header + topNode.subtreeToString(2) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala index 538c0e2332..4d66d6d81c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala @@ -73,17 +73,27 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext def numTrees: Int = trees.size /** - * Print full model. + * Get total number of nodes, summed over all trees in the forest. */ - override def toString: String = { - val header = algo match { - case Classification => - s"RandomForestModel classifier with $numTrees trees\n" - case Regression => - s"RandomForestModel regressor with $numTrees trees\n" - case _ => throw new IllegalArgumentException( - s"RandomForestModel given unknown algo parameter: $algo.") - } + def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum + + /** + * Print a summary of the model. + */ + override def toString: String = algo match { + case Classification => + s"RandomForestModel classifier with $numTrees trees" + case Regression => + s"RandomForestModel regressor with $numTrees trees" + case _ => throw new IllegalArgumentException( + s"RandomForestModel given unknown algo parameter: $algo.") + } + + /** + * Print the full model to a string. 
+ */ + def toDebugString: String = { + val header = toString + "\n" header + trees.zipWithIndex.map { case (tree, treeIndex) => s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) }.fold("")(_ + _) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index f59a818a6e..afdcdbdf3a 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -77,8 +77,13 @@ class DecisionTreeModel(object): return self._java_model.depth() def __repr__(self): + """ Print summary of model. """ return self._java_model.toString() + def toDebugString(self): + """ Print full model. """ + return self._java_model.toDebugString() + class DecisionTree(object): @@ -135,7 +140,6 @@ class DecisionTree(object): >>> from numpy import array >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import DecisionTree - >>> from pyspark.mllib.linalg import SparseVector >>> >>> data = [ ... LabeledPoint(0.0, [0.0]), @@ -145,7 +149,9 @@ class DecisionTree(object): ... ] >>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {}) >>> print model, # it already has newline - DecisionTreeModel classifier + DecisionTreeModel classifier of depth 1 with 3 nodes + >>> print model.toDebugString(), # it already has newline + DecisionTreeModel classifier of depth 1 with 3 nodes If (feature 0 <= 0.5) Predict: 0.0 Else (feature 0 > 0.5) -- GitLab