diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index d3376a7dff938fb2d06e0ff4ca41884f15c6224b..ea733d577a5fd71bfd2678a32b7986e7cdbb697c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -650,7 +650,7 @@ private[ml] object RandomForest extends Logging {
    * @param binAggregates Bin statistics.
    * @return tuple for best split: (Split, information gain, prediction at node)
    */
-  private def binsToBestSplit(
+  private[tree] def binsToBestSplit(
       binAggregates: DTStatsAggregator,
       splits: Array[Array[Split]],
       featuresForNode: Option[Array[Int]],
@@ -720,32 +720,30 @@ private[ml] object RandomForest extends Logging {
            *
            * centroidForCategories is a list: (category, centroid)
            */
-          val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
-            // For categorical variables in multiclass classification,
-            // the bins are ordered by the impurity of their corresponding labels.
-            Range(0, numCategories).map { case featureValue =>
-              val categoryStats =
-                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-              val centroid = if (categoryStats.count != 0) {
+          val centroidForCategories = Range(0, numCategories).map { case featureValue =>
+            val categoryStats =
+              binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+            val centroid = if (categoryStats.count != 0) {
+              if (binAggregates.metadata.isMulticlass) {
+                // multiclass classification
+                // For categorical variables in multiclass classification,
+                // the bins are ordered by the impurity of their corresponding labels.
                 categoryStats.calculate()
+              } else if (binAggregates.metadata.isClassification) {
+                // binary classification
+                // For categorical variables in binary classification,
+                // the bins are ordered by the count of class 1.
+                categoryStats.stats(1)
               } else {
-                Double.MaxValue
-              }
-              (featureValue, centroid)
-            }
-          } else { // regression or binary classification
-            // For categorical variables in regression and binary classification,
-            // the bins are ordered by the centroid of their corresponding labels.
-            Range(0, numCategories).map { case featureValue =>
-              val categoryStats =
-                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-              val centroid = if (categoryStats.count != 0) {
+                // regression
+                // For categorical variables in regression and binary classification,
+                // the bins are ordered by the prediction.
                 categoryStats.predict
-              } else {
-                Double.MaxValue
               }
-              (featureValue, centroid)
+            } else {
+              Double.MaxValue
             }
+            (featureValue, centroid)
           }
 
           logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 07ba0d8ccb2a8d473206f8ab3eccd500216899fc..51235a23711a167a99c7ca2072a41603ef27b225 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -791,7 +791,7 @@ object DecisionTree extends Serializable with Logging {
    * @param binAggregates Bin statistics.
    * @return tuple for best split: (Split, information gain, prediction at node)
    */
-  private def binsToBestSplit(
+  private[tree] def binsToBestSplit(
       binAggregates: DTStatsAggregator,
       splits: Array[Array[Split]],
       featuresForNode: Option[Array[Int]],
@@ -808,128 +808,127 @@ object DecisionTree extends Serializable with Logging {
     // For each (feature, split), calculate the gain, and select the best (feature, split).
     val (bestSplit, bestSplitStats) =
       Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
-      val featureIndex = if (featuresForNode.nonEmpty) {
-        featuresForNode.get.apply(featureIndexIdx)
-      } else {
-        featureIndexIdx
-      }
-      val numSplits = binAggregates.metadata.numSplits(featureIndex)
-      if (binAggregates.metadata.isContinuous(featureIndex)) {
-        // Cumulative sum (scanLeft) of bin statistics.
-        // Afterwards, binAggregates for a bin is the sum of aggregates for
-        // that bin + all preceding bins.
-        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
-        var splitIndex = 0
-        while (splitIndex < numSplits) {
-          binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
-          splitIndex += 1
+        val featureIndex = if (featuresForNode.nonEmpty) {
+          featuresForNode.get.apply(featureIndexIdx)
+        } else {
+          featureIndexIdx
         }
-        // Find best split.
-        val (bestFeatureSplitIndex, bestFeatureGainStats) =
-          Range(0, numSplits).map { case splitIdx =>
-            val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
-            val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
-            rightChildStats.subtract(leftChildStats)
-            predictWithImpurity = Some(predictWithImpurity.getOrElse(
-              calculatePredictImpurity(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
-            (splitIdx, gainStats)
-          }.maxBy(_._2.gain)
-        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
-      } else if (binAggregates.metadata.isUnordered(featureIndex)) {
-        // Unordered categorical feature
-        val (leftChildOffset, rightChildOffset) =
-          binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
-        val (bestFeatureSplitIndex, bestFeatureGainStats) =
-          Range(0, numSplits).map { splitIndex =>
-            val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-            val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
-            predictWithImpurity = Some(predictWithImpurity.getOrElse(
-              calculatePredictImpurity(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
-            (splitIndex, gainStats)
-          }.maxBy(_._2.gain)
-        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
-      } else {
-        // Ordered categorical feature
-        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
-        val numBins = binAggregates.metadata.numBins(featureIndex)
-
-        /* Each bin is one category (feature value).
-         * The bins are ordered based on centroidForCategories, and this ordering determines which
-         * splits are considered.  (With K categories, we consider K - 1 possible splits.)
-         *
-         * centroidForCategories is a list: (category, centroid)
-         */
-        val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
-          // For categorical variables in multiclass classification,
-          // the bins are ordered by the impurity of their corresponding labels.
-          Range(0, numBins).map { case featureValue =>
-            val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-            val centroid = if (categoryStats.count != 0) {
-              categoryStats.calculate()
-            } else {
-              Double.MaxValue
-            }
-            (featureValue, centroid)
+        val numSplits = binAggregates.metadata.numSplits(featureIndex)
+        if (binAggregates.metadata.isContinuous(featureIndex)) {
+          // Cumulative sum (scanLeft) of bin statistics.
+          // Afterwards, binAggregates for a bin is the sum of aggregates for
+          // that bin + all preceding bins.
+          val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+          var splitIndex = 0
+          while (splitIndex < numSplits) {
+            binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+            splitIndex += 1
           }
-        } else { // regression or binary classification
-          // For categorical variables in regression and binary classification,
-          // the bins are ordered by the centroid of their corresponding labels.
-          Range(0, numBins).map { case featureValue =>
-            val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+          // Find best split.
+          val (bestFeatureSplitIndex, bestFeatureGainStats) =
+            Range(0, numSplits).map { case splitIdx =>
+              val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+              val rightChildStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+              rightChildStats.subtract(leftChildStats)
+              predictWithImpurity = Some(predictWithImpurity.getOrElse(
+                calculatePredictImpurity(leftChildStats, rightChildStats)))
+              val gainStats = calculateGainForSplit(leftChildStats,
+                rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+              (splitIdx, gainStats)
+            }.maxBy(_._2.gain)
+          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+        } else if (binAggregates.metadata.isUnordered(featureIndex)) {
+          // Unordered categorical feature
+          val (leftChildOffset, rightChildOffset) =
+            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+          val (bestFeatureSplitIndex, bestFeatureGainStats) =
+            Range(0, numSplits).map { splitIndex =>
+              val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+              val rightChildStats =
+                binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+              predictWithImpurity = Some(predictWithImpurity.getOrElse(
+                calculatePredictImpurity(leftChildStats, rightChildStats)))
+              val gainStats = calculateGainForSplit(leftChildStats,
+                rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+              (splitIndex, gainStats)
+            }.maxBy(_._2.gain)
+          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+        } else {
+          // Ordered categorical feature
+          val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+          val numBins = binAggregates.metadata.numBins(featureIndex)
+
+          /* Each bin is one category (feature value).
+           * The bins are ordered based on centroidForCategories, and this ordering determines which
+           * splits are considered.  (With K categories, we consider K - 1 possible splits.)
+           *
+           * centroidForCategories is a list: (category, centroid)
+           */
+          val centroidForCategories = Range(0, numBins).map { case featureValue =>
+            val categoryStats =
+              binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
             val centroid = if (categoryStats.count != 0) {
-              categoryStats.predict
+              if (binAggregates.metadata.isMulticlass) {
+                // For categorical variables in multiclass classification,
+                // the bins are ordered by the impurity of their corresponding labels.
+                categoryStats.calculate()
+              } else if (binAggregates.metadata.isClassification) {
+                // For categorical variables in binary classification,
+                // the bins are ordered by the count of class 1.
+                categoryStats.stats(1)
+              } else {
+                // For categorical variables in regression,
+                // the bins are ordered by the prediction.
+                categoryStats.predict
+              }
             } else {
               Double.MaxValue
             }
             (featureValue, centroid)
           }
-        }
 
-        logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
+          logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
 
-        // bins sorted by centroids
-        val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
+          // bins sorted by centroids
+          val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
 
-        logDebug("Sorted centroids for categorical variable = " +
-          categoriesSortedByCentroid.mkString(","))
+          logDebug("Sorted centroids for categorical variable = " +
+            categoriesSortedByCentroid.mkString(","))
 
-        // Cumulative sum (scanLeft) of bin statistics.
-        // Afterwards, binAggregates for a bin is the sum of aggregates for
-        // that bin + all preceding bins.
-        var splitIndex = 0
-        while (splitIndex < numSplits) {
-          val currentCategory = categoriesSortedByCentroid(splitIndex)._1
-          val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
-          binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
-          splitIndex += 1
+          // Cumulative sum (scanLeft) of bin statistics.
+          // Afterwards, binAggregates for a bin is the sum of aggregates for
+          // that bin + all preceding bins.
+          var splitIndex = 0
+          while (splitIndex < numSplits) {
+            val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+            val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+            binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
+            splitIndex += 1
+          }
+          // lastCategory = index of bin with total aggregates for this (node, feature)
+          val lastCategory = categoriesSortedByCentroid.last._1
+          // Find best split.
+          val (bestFeatureSplitIndex, bestFeatureGainStats) =
+            Range(0, numSplits).map { splitIndex =>
+              val featureValue = categoriesSortedByCentroid(splitIndex)._1
+              val leftChildStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+              val rightChildStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+              rightChildStats.subtract(leftChildStats)
+              predictWithImpurity = Some(predictWithImpurity.getOrElse(
+                calculatePredictImpurity(leftChildStats, rightChildStats)))
+              val gainStats = calculateGainForSplit(leftChildStats,
+                rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+              (splitIndex, gainStats)
+            }.maxBy(_._2.gain)
+          val categoriesForSplit =
+            categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+          val bestFeatureSplit =
+            new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
+          (bestFeatureSplit, bestFeatureGainStats)
         }
-        // lastCategory = index of bin with total aggregates for this (node, feature)
-        val lastCategory = categoriesSortedByCentroid.last._1
-        // Find best split.
-        val (bestFeatureSplitIndex, bestFeatureGainStats) =
-          Range(0, numSplits).map { splitIndex =>
-            val featureValue = categoriesSortedByCentroid(splitIndex)._1
-            val leftChildStats =
-              binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-            val rightChildStats =
-              binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
-            rightChildStats.subtract(leftChildStats)
-            predictWithImpurity = Some(predictWithImpurity.getOrElse(
-              calculatePredictImpurity(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
-            (splitIndex, gainStats)
-          }.maxBy(_._2.gain)
-        val categoriesForSplit =
-          categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
-        val bestFeatureSplit =
-          new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
-        (bestFeatureSplit, bestFeatureGainStats)
-      }
     }.maxBy(_._2.gain)
 
     (bestSplit, bestSplitStats, predictWithImpurity.get._1)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index fda2711fed0fd2382c70c9a15e6c256d9dc475a9..baf6b9083900f2c5493d5dc76c9a948b67d232c8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.impl.TreeTests
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
 import org.apache.spark.ml.util.MLTestingUtils
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -275,6 +275,40 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
     val model = dt.fit(df)
   }
 
+  test("Use soft prediction for binary classification with ordered categorical features") {
+    // The following dataset is set up such that the best split is {1} vs. {0, 2}.
+    // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(1.0, Vectors.dense(2.0)))
+    val data = sc.parallelize(arr)
+    val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
+
+    // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
+    val dt = new DecisionTreeClassifier()
+      .setImpurity("gini")
+      .setMaxDepth(1)
+      .setMaxBins(3)
+    val model = dt.fit(df)
+    model.rootNode match {
+      case n: InternalNode =>
+        n.split match {
+          case s: CategoricalSplit =>
+            assert(s.leftCategories === Array(1.0))
+        }
+    }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index a9c935bd4244509fb78863761b7199f9db1c5d95..dca8ea815aa6adef98014d116cbda4d59a1cf224 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.util.Utils
 
 
@@ -337,6 +338,35 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.rightNode.get.impurity === 0.0)
   }
 
+  test("Use soft prediction for binary classification with ordered categorical features") {
+    // The following dataset is set up such that the best split is {1} vs. {0, 2}.
+    // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(1.0, Vectors.dense(2.0)))
+    val input = sc.parallelize(arr)
+
+    // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
+
+    val model = new DecisionTree(strategy).run(input)
+    model.topNode.split.get match {
+      case Split(_, _, _, categories: List[Double]) =>
+        assert(categories === List(1.0))
+    }
+  }
+
   test("Second level node building with vs. without groups") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
     assert(arr.length === 1000)