From 7ecf0c46990c39df8aeddbd64ca33d01824bcc0a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" <joseph@databricks.com> Date: Fri, 14 Aug 2015 10:48:02 -0700 Subject: [PATCH] [SPARK-9956] [ML] Make trees work with one-category features This modifies DecisionTreeMetadata construction to treat 1-category features as continuous, so that trees do not fail with such features. It is important for the pipelines API, where VectorIndexer can automatically categorize certain features as categorical. As stated in the JIRA, this is a temp fix which we can improve upon later by automatically filtering out those features. That will take longer, though, since it will require careful indexing. Targeted for 1.5 and master CC: manishamde mengxr yanboliang Author: Joseph K. Bradley <joseph@databricks.com> Closes #8187 from jkbradley/tree-1cat. --- .../tree/impl/DecisionTreeMetadata.scala | 27 ++++++++++++------- .../DecisionTreeClassifierSuite.scala | 13 +++++++++ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 9fe264656e..21ee49c457 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -144,21 +144,28 @@ private[spark] object DecisionTreeMetadata extends Logging { val maxCategoriesForUnorderedFeature = ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // Decide if some categorical features should be treated as unordered features, - // which require 2 * ((1 << numCategories - 1) - 1) bins. - // We do this check with log values to prevent overflows in case numCategories is large. - // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins - if (numCategories <= maxCategoriesForUnorderedFeature) { - unorderedFeatures.add(featureIndex) - numBins(featureIndex) = numUnorderedBins(numCategories) - } else { - numBins(featureIndex) = numCategories + // Hack: If a categorical feature has only 1 category, we treat it as continuous. + // TODO(SPARK-9957): Handle this properly by filtering out those features. + if (numCategories > 1) { + // Decide if some categorical features should be treated as unordered features, + // which require 2 * ((1 << numCategories - 1) - 1) bins. + // We do this check with log values to prevent overflows in case numCategories is large. + // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins + if (numCategories <= maxCategoriesForUnorderedFeature) { + unorderedFeatures.add(featureIndex) + numBins(featureIndex) = numUnorderedBins(numCategories) + } else { + numBins(featureIndex) = numCategories + } } } } else { // Binary classification or regression strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - numBins(featureIndex) = numCategories + // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957 + if (numCategories > 1) { + numBins(featureIndex) = numCategories + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 4b7c5d3f23..f680d8d3c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -261,6 +261,19 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte } } + test("training with 1-category categorical feature") { + val data = sc.parallelize(Seq( + LabeledPoint(0, Vectors.dense(0, 2, 3)), + LabeledPoint(1, Vectors.dense(0, 3, 1)), + LabeledPoint(0, Vectors.dense(0, 2, 2)), + LabeledPoint(1, Vectors.dense(0, 3, 9)), + LabeledPoint(0, Vectors.dense(0, 2, 6)) + )) + val df = TreeTests.setMetadata(data, Map(0 -> 1), 2) + val dt = new DecisionTreeClassifier().setMaxDepth(3) + val model = dt.fit(df) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// -- GitLab