From 7ecf0c46990c39df8aeddbd64ca33d01824bcc0a Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph@databricks.com>
Date: Fri, 14 Aug 2015 10:48:02 -0700
Subject: [PATCH] [SPARK-9956] [ML] Make trees work with one-category features

This modifies DecisionTreeMetadata construction to treat 1-category features as continuous, so that trees do not fail with such features.  It is important for the pipelines API, where VectorIndexer can automatically categorize certain features as categorical.

As stated in the JIRA, this is a temp fix which we can improve upon later by automatically filtering out those features. That will take longer, though, since it will require careful indexing.

Targeted for 1.5 and master

CC: manishamde  mengxr yanboliang

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8187 from jkbradley/tree-1cat.
---
 .../tree/impl/DecisionTreeMetadata.scala      | 27 ++++++++++++-------
 .../DecisionTreeClassifierSuite.scala         | 13 +++++++++
 2 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 9fe264656e..21ee49c457 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -144,21 +144,28 @@ private[spark] object DecisionTreeMetadata extends Logging {
       val maxCategoriesForUnorderedFeature =
         ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
       strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
-        // Decide if some categorical features should be treated as unordered features,
-        //  which require 2 * ((1 << numCategories - 1) - 1) bins.
-        // We do this check with log values to prevent overflows in case numCategories is large.
-        // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
-        if (numCategories <= maxCategoriesForUnorderedFeature) {
-          unorderedFeatures.add(featureIndex)
-          numBins(featureIndex) = numUnorderedBins(numCategories)
-        } else {
-          numBins(featureIndex) = numCategories
+        // Hack: If a categorical feature has only 1 category, we treat it as continuous.
+        // TODO(SPARK-9957): Handle this properly by filtering out those features.
+        if (numCategories > 1) {
+          // Decide if some categorical features should be treated as unordered features,
+          //  which require 2 * ((1 << numCategories - 1) - 1) bins.
+          // We do this check with log values to prevent overflows in case numCategories is large.
+          // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
+          if (numCategories <= maxCategoriesForUnorderedFeature) {
+            unorderedFeatures.add(featureIndex)
+            numBins(featureIndex) = numUnorderedBins(numCategories)
+          } else {
+            numBins(featureIndex) = numCategories
+          }
         }
       }
     } else {
       // Binary classification or regression
       strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
-        numBins(featureIndex) = numCategories
+        // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957
+        if (numCategories > 1) {
+          numBins(featureIndex) = numCategories
+        }
       }
     }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 4b7c5d3f23..f680d8d3c4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -261,6 +261,19 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
     }
   }
 
+  test("training with 1-category categorical feature") {
+    val data = sc.parallelize(Seq(
+      LabeledPoint(0, Vectors.dense(0, 2, 3)),
+      LabeledPoint(1, Vectors.dense(0, 3, 1)),
+      LabeledPoint(0, Vectors.dense(0, 2, 2)),
+      LabeledPoint(1, Vectors.dense(0, 3, 9)),
+      LabeledPoint(0, Vectors.dense(0, 2, 6))
+    ))
+    val df = TreeTests.setMetadata(data, Map(0 -> 1), 2)
+    val dt = new DecisionTreeClassifier().setMaxDepth(3)
+    val model = dt.fit(df)
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
-- 
GitLab