diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 9fe264656ede7d200dd1c38c06a0a1dd54a019d9..21ee49c45788c8e715116f6e7c702347d23fe689 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -144,21 +144,28 @@ private[spark] object DecisionTreeMetadata extends Logging {
       val maxCategoriesForUnorderedFeature =
         ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
       strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
-        // Decide if some categorical features should be treated as unordered features,
-        //  which require 2 * ((1 << numCategories - 1) - 1) bins.
-        // We do this check with log values to prevent overflows in case numCategories is large.
-        // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
-        if (numCategories <= maxCategoriesForUnorderedFeature) {
-          unorderedFeatures.add(featureIndex)
-          numBins(featureIndex) = numUnorderedBins(numCategories)
-        } else {
-          numBins(featureIndex) = numCategories
+        // Hack: If a categorical feature has only 1 category, we treat it as continuous.
+        // TODO(SPARK-9957): Handle this properly by filtering out those features.
+        if (numCategories > 1) {
+          // Decide if some categorical features should be treated as unordered features,
+          //  which require 2 * ((1 << numCategories - 1) - 1) bins.
+          // We do this check with log values to prevent overflows in case numCategories is large.
+          // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
+          if (numCategories <= maxCategoriesForUnorderedFeature) {
+            unorderedFeatures.add(featureIndex)
+            numBins(featureIndex) = numUnorderedBins(numCategories)
+          } else {
+            numBins(featureIndex) = numCategories
+          }
         }
       }
     } else {
       // Binary classification or regression
       strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
-        numBins(featureIndex) = numCategories
+        // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957
+        if (numCategories > 1) {
+          numBins(featureIndex) = numCategories
+        }
       }
     }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 4b7c5d3f23d2cf2da7a0684a8265e91fa4ee8beb..f680d8d3c4cc2c6ff582a5e4ef72492b9bfc8441 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -261,6 +261,19 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
     }
   }
 
+  test("training with 1-category categorical feature") {
+    val data = sc.parallelize(Seq(
+      LabeledPoint(0, Vectors.dense(0, 2, 3)),
+      LabeledPoint(1, Vectors.dense(0, 3, 1)),
+      LabeledPoint(0, Vectors.dense(0, 2, 2)),
+      LabeledPoint(1, Vectors.dense(0, 3, 9)),
+      LabeledPoint(0, Vectors.dense(0, 2, 6))
+    ))
+    val df = TreeTests.setMetadata(data, Map(0 -> 1), 2)
+    val dt = new DecisionTreeClassifier().setMaxDepth(3)
+    val model = dt.fit(df)
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////