diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index dd9a5f261f60f86c0190fe5b063c0b18f0b0518a..afbb9d974d42a22b50614896d41ec4bafa730122 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -477,8 +477,8 @@ private[ml] object RandomForest extends Logging {
         // Construct a nodeStatsAggregators array to hold node aggregate stats;
         // each node gets its own DTStatsAggregator.
         val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
-          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
-            Some(nodeToFeatures(nodeIndex))
+          val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
+            nodeToFeatures(nodeIndex)
           }
           new DTStatsAggregator(metadata, featuresForNode)
         }
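
The first hunk is a pure simplification: for any Option, flatMap(x => Some(f(x))) is equivalent to map(f). A minimal standalone sketch of the equivalence, reusing the nodeToFeatures shape from the hunk for illustration:

object OptionMapSketch {
  def main(args: Array[String]): Unit = {
    // nodeToFeatures maps each node index to the feature subset it considers.
    val nodeToFeatures: Option[Array[Array[Int]]] = Some(Array(Array(0, 2), Array(1, 3)))
    val nodeIndex = 1

    // Old form: wrap the result in Some, then flatten the extra Option layer out.
    val viaFlatMap = nodeToFeatures.flatMap(n2f => Some(n2f(nodeIndex)))
    // New form: map directly; map preserves None and never adds a layer.
    val viaMap = nodeToFeatures.map(n2f => n2f(nodeIndex))

    assert(viaFlatMap.map(_.toSeq) == viaMap.map(_.toSeq))
  }
}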
@@ -827,8 +827,8 @@ private[ml] object RandomForest extends Logging {
     val numFeatures = metadata.numFeatures
 
     // Sample the input only if there are continuous features.
-    val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
-    val sampledInput = if (hasContinuousFeatures) {
+    val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
+    val sampledInput = if (continuousFeatures.nonEmpty) {
       // Calculate the number of samples for approximate quantile calculation.
       val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
       val fraction = if (requiredSamples < metadata.numExamples) {
@@ -837,58 +837,57 @@ private[ml] object RandomForest extends Logging {
         1.0
       }
       logDebug("fraction of data used for calculating quantiles = " + fraction)
-      input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
+      input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
     } else {
-      new Array[LabeledPoint](0)
+      input.sparkContext.emptyRDD[LabeledPoint]
     }
 
-    val splits = new Array[Array[Split]](numFeatures)
-
-    // Find all splits.
-    // Iterate over all features.
-    var featureIndex = 0
-    while (featureIndex < numFeatures) {
-      if (metadata.isContinuous(featureIndex)) {
-        val featureSamples = sampledInput.map(_.features(featureIndex))
-        val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)
+    findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
+  }
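
For a sense of scale in the sampling above (illustrative numbers, not taken from the patch): with maxBins = 32, requiredSamples = max(32 * 32, 10000) = 10000, so on an input of 1,000,000 examples the fraction is 0.01 and roughly 10,000 points feed the approximate quantile calculation; on inputs at or below 10,000 examples the fraction is 1.0 and everything is used.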
 
-        val numSplits = featureSplits.length
-        logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
-        splits(featureIndex) = new Array[Split](numSplits)
+  private def findSplitsBinsBySorting(
+      input: RDD[LabeledPoint],
+      metadata: DecisionTreeMetadata,
+      continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
+
+    val continuousSplits: scala.collection.Map[Int, Array[Split]] = {
+      // Reduce the parallelism for split computations when there are fewer
+      // continuous features than input partitions. This prevents tasks from
+      // being spun up that will definitely do no work.
+      val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
+
+      input
+        .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
+        .groupByKey(numPartitions)
+        .map { case (idx, samples) =>
+          val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
+          val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
+          logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
+          (idx, splits)
+        }.collectAsMap()
+    }
 
-        var splitIndex = 0
-        while (splitIndex < numSplits) {
-          val threshold = featureSplits(splitIndex)
-          splits(featureIndex)(splitIndex) = new ContinuousSplit(featureIndex, threshold)
-          splitIndex += 1
-        }
-      } else {
-        // Categorical feature
-        if (metadata.isUnordered(featureIndex)) {
-          val numSplits = metadata.numSplits(featureIndex)
-          val featureArity = metadata.featureArity(featureIndex)
-          // TODO: Use an implicit representation mapping each category to a subset of indices.
-          //       I.e., track indices such that we can calculate the set of bins for which
-          //       feature value x splits to the left.
-          // Unordered features
-          // 2^(maxFeatureValue - 1) - 1 combinations
-          splits(featureIndex) = new Array[Split](numSplits)
-          var splitIndex = 0
-          while (splitIndex < numSplits) {
-            val categories: List[Double] =
-              extractMultiClassCategories(splitIndex + 1, featureArity)
-            splits(featureIndex)(splitIndex) =
-              new CategoricalSplit(featureIndex, categories.toArray, featureArity)
-            splitIndex += 1
-          }
-        } else {
-          // Ordered features
-          //   Bins correspond to feature values, so we do not need to compute splits or bins
-          //   beforehand.  Splits are constructed as needed during training.
-          splits(featureIndex) = new Array[Split](0)
+    val numFeatures = metadata.numFeatures
+    val splits: Array[Array[Split]] = Array.tabulate(numFeatures) {
+      case i if metadata.isContinuous(i) =>
+        val featureSplits = continuousSplits(i)
+        metadata.setNumSplits(i, featureSplits.length)
+        featureSplits
+
+      case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
+        // Unordered features
+        // 2^(maxFeatureValue - 1) - 1 combinations
+        val featureArity = metadata.featureArity(i)
+        Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
+          val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
+          new CategoricalSplit(i, categories.toArray, featureArity)
         }
-      }
-      featureIndex += 1
+
+      case i if metadata.isCategorical(i) =>
+        // Ordered features
+        //   Bins correspond to feature values, so we do not need to compute splits or bins
+        //   beforehand.  Splits are constructed as needed during training.
+        Array.empty[Split]
     }
     splits
   }
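
The restructured method keeps the sampled input as an RDD and computes splits for all continuous features in one distributed pass, rather than collecting samples to the driver first. A self-contained sketch of the same flatMap / groupByKey / collectAsMap pattern on toy data (the rows and feature indices are made up for illustration):

import org.apache.spark.{SparkConf, SparkContext}

object GroupByFeatureSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    // Toy input: each row is a dense feature vector; features 0 and 1 are continuous.
    val rows = sc.parallelize(Seq(Array(1.0, 10.0), Array(2.0, 20.0), Array(3.0, 30.0)))
    val continuousFeatures = Seq(0, 1)

    // Emit one (featureIndex, value) pair per continuous feature per row, then
    // group so that all samples for a feature land in a single task. Capping the
    // partition count at the number of features avoids tasks that get no keys.
    val numPartitions = math.min(continuousFeatures.length, rows.partitions.length)
    val samplesByFeature = rows
      .flatMap(row => continuousFeatures.map(idx => (idx, row(idx))))
      .groupByKey(numPartitions)
      .mapValues(_.toSeq.sorted)  // stand-in for findSplitsForContinuousFeature
      .collectAsMap()             // small result: one entry per continuous feature

    samplesByFeature.foreach { case (idx, vs) => println(s"feature $idx -> $vs") }
    sc.stop()
  }
}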
@@ -930,7 +929,7 @@ private[ml] object RandomForest extends Logging {
    * @return array of splits
    */
   private[tree] def findSplitsForContinuousFeature(
-      featureSamples: Array[Double],
+      featureSamples: Iterable[Double],
       metadata: DecisionTreeMetadata,
       featureIndex: Int): Array[Double] = {
     require(metadata.isContinuous(featureIndex),
@@ -940,8 +939,9 @@ private[ml] object RandomForest extends Logging {
       val numSplits = metadata.numSplits(featureIndex)
 
       // get count for each distinct value
-      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
-        m + ((x, m.getOrElse(x, 0) + 1))
+      val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
+        case ((m, cnt), x) =>
+          (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
       }
       // sort distinct values
       val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
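
Because the parameter is now an Iterable, calling .length later could force a second traversal (and is not guaranteed cheap for non-indexed collections), so the sample count is accumulated in the same fold that builds the value-count map. A standalone sketch of that single-pass fold:

object FoldCountSketch {
  def main(args: Array[String]): Unit = {
    val featureSamples: Iterable[Double] = Seq(1.0, 2.0, 2.0, 3.0, 3.0, 3.0)

    // One traversal produces both the per-value counts and the total count.
    val (valueCountMap, numSamples) =
      featureSamples.foldLeft((Map.empty[Double, Int], 0)) { case ((m, cnt), x) =>
        (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
      }

    assert(valueCountMap == Map(1.0 -> 1, 2.0 -> 2, 3.0 -> 3))
    assert(numSamples == 6)
  }
}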
@@ -952,7 +952,7 @@ private[ml] object RandomForest extends Logging {
         valueCounts.map(_._1)
       } else {
         // stride between splits
-        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+        val stride: Double = numSamples.toDouble / (numSplits + 1)
         logDebug("stride = " + stride)
 
         // iterate `valueCount` to find splits
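
Worked example for the stride (illustrative numbers): with numSamples = 1000 and numSplits = 31, stride = 1000 / 32 = 31.25, so split thresholds are aimed at every 31.25th sample in sorted order, approximating 31 evenly spaced quantiles.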
@@ -988,8 +988,6 @@ private[ml] object RandomForest extends Logging {
     assert(splits.length > 0,
       s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
         "  Please remove this feature and then try again.")
-    // set number of splits accordingly
-    metadata.setNumSplits(featureIndex, splits.length)
 
     splits
   }
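
The setNumSplits call is dropped here because this function now runs inside an RDD task on executors, where mutating metadata would only update a task-local deserialized copy; the patch instead moves the update to the driver, in the Array.tabulate of findSplitsBinsBySorting above. A minimal sketch of why executor-side mutation is lost (standard Spark closure-serialization behavior, with made-up names):

import org.apache.spark.{SparkConf, SparkContext}

object ClosureMutationSketch {
  // A tiny mutable holder standing in for DecisionTreeMetadata's numSplits state.
  class Counter extends Serializable { var n = 0 }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    val counter = new Counter

    // Each task mutates its own deserialized copy of `counter`;
    // the driver's instance never sees the updates.
    sc.parallelize(1 to 100).foreach(_ => counter.n += 1)
    println(counter.n)  // prints 0 on the driver

    sc.stop()
  }
}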
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index c0934d241f50a59018706267729d61e0191448db..8f02e098acc304a0fa51e980af65960611953c5d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -1010,7 +1010,7 @@ object DecisionTree extends Serializable with Logging {
         featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = {
       val splits = {
         val featureSplits = findSplitsForContinuousFeature(
-          featureSamples.toArray,
+          featureSamples,
           metadata,
           featureIndex)
         logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}")
@@ -1115,7 +1115,7 @@ object DecisionTree extends Serializable with Logging {
    * @return Array of splits.
    */
   private[tree] def findSplitsForContinuousFeature(
-      featureSamples: Array[Double],
+      featureSamples: Iterable[Double],
       metadata: DecisionTreeMetadata,
       featureIndex: Int): Array[Double] = {
     require(metadata.isContinuous(featureIndex),
@@ -1125,8 +1125,9 @@ object DecisionTree extends Serializable with Logging {
       val numSplits = metadata.numSplits(featureIndex)
 
       // get count for each distinct value
-      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
-        m + ((x, m.getOrElse(x, 0) + 1))
+      val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
+        case ((m, cnt), x) =>
+          (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
       }
       // sort distinct values
       val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
@@ -1137,7 +1138,7 @@ object DecisionTree extends Serializable with Logging {
         valueCounts.map(_._1)
       } else {
         // stride between splits
-        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+        val stride: Double = numSamples.toDouble / (numSplits + 1)
         logDebug("stride = " + stride)
 
         // iterate `valueCount` to find splits
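
Both versions of the function end here, just before the loop the comment announces. As a hedged illustration of the technique (a sketch of stride-based selection, not a claim about the patch's exact loop), here is a standalone example that walks sorted (value, count) pairs and emits a split each time the cumulative count crosses a multiple of the stride:

object StrideSplitSketch {
  // Pick ~numSplits thresholds so that each resulting bin holds ~stride samples.
  def strideSplits(valueCounts: Array[(Double, Int)], numSplits: Int): Array[Double] = {
    val numSamples = valueCounts.map(_._2).sum
    val stride = numSamples.toDouble / (numSplits + 1)
    val splits = Array.newBuilder[Double]
    var cumulative = 0.0
    var target = stride
    // The largest value is never a split: everything would fall to its left.
    valueCounts.dropRight(1).foreach { case (value, count) =>
      cumulative += count
      if (cumulative >= target) {  // crossed a quantile boundary: split after this value
        splits += value
        target += stride
      }
    }
    splits.result()
  }

  def main(args: Array[String]): Unit = {
    // 6 distinct values, 100 samples, 3 requested splits -> thresholds near quartiles.
    val vc = Array(1.0 -> 10, 2.0 -> 20, 3.0 -> 25, 4.0 -> 20, 5.0 -> 15, 6.0 -> 10)
    println(strideSplits(vc, 3).mkString(", "))  // 2.0, 3.0, 4.0
  }
}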