diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 71c8c42ce5eba23b3ee59d46fd52b8b4daa259e9..0b7ad92b3cf30885bab0df3481b8eee602e1124f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -51,7 +51,7 @@ import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
  * findSplits() method during initialization, after which each continuous feature becomes
  * an ordered discretized feature with at most maxBins possible values.
  *
- * The main loop in the algorithm operates on a queue of nodes (nodeQueue).  These nodes
+ * The main loop in the algorithm operates on a stack of nodes (nodeStack).  These nodes
  * lie at the periphery of the tree being trained.  If multiple trees are being trained at once,
- * then this queue contains nodes from all of them.  Each iteration works roughly as follows:
+ * then this stack contains nodes from all of them.  Each iteration works roughly as follows:
  *   On the master node:
@@ -161,31 +161,42 @@ private[spark] object RandomForest extends Logging {
       None
     }
 
-    // FIFO queue of nodes to train: (treeIndex, node)
-    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+    /*
+      Stack of nodes to train: (treeIndex, node)
+      This is a stack (LIFO) rather than a queue because we train many trees at once but want
+      to finish individual trees rather than advancing all of them simultaneously.  When nodes
+      from one tree are split, the new child nodes are pushed onto the top of this stack, so the
+      next iteration continues training the same tree.  This focus lets us send fewer trees to
+      the workers on each iteration; see topNodesForGroup below.
+     */
+    val nodeStack = new mutable.Stack[(Int, LearningNode)]
 
     val rng = new Random()
     rng.setSeed(seed)
 
-    // Allocate and queue root nodes.
+    // Allocate root nodes and push them onto the stack.
     val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
-    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
+    Range(0, numTrees).foreach(treeIndex => nodeStack.push((treeIndex, topNodes(treeIndex))))
 
     timer.stop("init")
 
-    while (nodeQueue.nonEmpty) {
+    while (nodeStack.nonEmpty) {
       // Collect some nodes to split, and choose features for each node (if subsampling).
       // Each group of nodes may come from one or multiple trees, and at multiple levels.
       val (nodesForGroup, treeToNodeToIndexInfo) =
-        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+        RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
       // Sanity check (should never occur):
       assert(nodesForGroup.nonEmpty,
         s"RandomForest selected empty nodesForGroup.  Error for unknown reason.")
 
+      // Only send trees to workers if they contain nodes being split this iteration.
+      val topNodesForGroup: Map[Int, LearningNode] =
+        nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap
+
-      // Choose node splits, and enqueue new nodes as needed.
+      // Choose node splits, and push new nodes onto the stack as needed.
       timer.start("findBestSplits")
-      RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
-        treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
+      RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup,
+        treeToNodeToIndexInfo, splits, nodeStack, timer, nodeIdCache)
       timer.stop("findBestSplits")
     }
 
@@ -334,13 +345,14 @@ private[spark] object RandomForest extends Logging {
    *
    * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]]
    * @param metadata Learning and dataset metadata
-   * @param topNodes Root node for each tree.  Used for matching instances with nodes.
+   * @param topNodesForGroup For each tree in the group, tree index -> root node.
+   *                         Used for matching instances with nodes.
    * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
    * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
    *                              where nodeIndexInfo stores the index in the group and the
    *                              feature subsets (if using feature subsets).
    * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
-   * @param nodeQueue  Queue of nodes to split, with values (treeIndex, node).
+   * @param nodeStack  Stack of nodes to split, with values (treeIndex, node).
    *                   Updated with new non-leaf nodes which are created.
    * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
    *                    each value in the array is the data point's node Id
@@ -351,11 +363,11 @@ private[spark] object RandomForest extends Logging {
   private[tree] def findBestSplits(
       input: RDD[BaggedPoint[TreePoint]],
       metadata: DecisionTreeMetadata,
-      topNodes: Array[LearningNode],
+      topNodesForGroup: Map[Int, LearningNode],
       nodesForGroup: Map[Int, Array[LearningNode]],
       treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
       splits: Array[Array[Split]],
-      nodeQueue: mutable.Queue[(Int, LearningNode)],
+      nodeStack: mutable.Stack[(Int, LearningNode)],
       timer: TimeTracker = new TimeTracker,
       nodeIdCache: Option[NodeIdCache] = None): Unit = {
 
@@ -437,7 +449,8 @@ private[spark] object RandomForest extends Logging {
         agg: Array[DTStatsAggregator],
         baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
       treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
-        val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
+        val nodeIndex =
+          topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
         nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
       }
       agg
@@ -593,10 +606,10 @@ private[spark] object RandomForest extends Logging {
 
-          // enqueue left child and right child if they are not leaves
+          // push left and right children onto the stack if they are not leaves
           if (!leftChildIsLeaf) {
-            nodeQueue.enqueue((treeIndex, node.leftChild.get))
+            nodeStack.push((treeIndex, node.leftChild.get))
           }
           if (!rightChildIsLeaf) {
-            nodeQueue.enqueue((treeIndex, node.rightChild.get))
+            nodeStack.push((treeIndex, node.rightChild.get))
           }
 
           logDebug("leftChildIndex = " + node.leftChild.get.id +
@@ -1029,7 +1042,7 @@ private[spark] object RandomForest extends Logging {
    * will be needed; this allows an adaptive number of nodes since different nodes may require
    * different amounts of memory (if featureSubsetStrategy is not "all").
    *
-   * @param nodeQueue  Queue of nodes to split.
+   * @param nodeStack  Stack of nodes to split.
    * @param maxMemoryUsage  Bound on size of aggregate statistics.
    * @return  (nodesForGroup, treeToNodeToIndexInfo).
    *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
@@ -1041,7 +1054,7 @@ private[spark] object RandomForest extends Logging {
    *          The feature indices are None if not subsampling features.
    */
   private[tree] def selectNodesToSplit(
-      nodeQueue: mutable.Queue[(Int, LearningNode)],
+      nodeStack: mutable.Stack[(Int, LearningNode)],
       maxMemoryUsage: Long,
       metadata: DecisionTreeMetadata,
       rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
@@ -1054,8 +1067,8 @@ private[spark] object RandomForest extends Logging {
     var numNodesInGroup = 0
     // If maxMemoryInMB is set very small, we want to still try to split 1 node,
     // so we allow one iteration if memUsage == 0.
-    while (nodeQueue.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) {
-      val (treeIndex, node) = nodeQueue.head
+    while (nodeStack.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) {
+      val (treeIndex, node) = nodeStack.top
       // Choose subset of features for node (if subsampling).
       val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
         Some(SamplingUtils.reservoirSampleAndCount(Range(0,
@@ -1066,7 +1079,7 @@ private[spark] object RandomForest extends Logging {
       // Check if enough memory remains to add this node to the group.
       val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
       if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) {
-        nodeQueue.dequeue()
+        nodeStack.pop()
         mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
           node
         mutableTreeToNodeToIndexInfo
@@ -1109,5 +1122,4 @@ private[spark] object RandomForest extends Logging {
       3 * totalBins
     }
   }
-
 }
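
Reviewer note (illustration only; not part of the patch): the core change above swaps the
FIFO queue of (treeIndex, node) pairs for a LIFO stack, so training drives one tree's
frontier toward completion before returning to the others.  The following self-contained
Scala sketch contrasts the two processing orders for two tiny trees; all names here
(StackVsQueueDemo, run, the "t0-root" labels) are hypothetical and this is not Spark code.

    import scala.collection.mutable

    object StackVsQueueDemo {
      // Process two tiny trees; handling a "root" spawns its two children.
      def run(useStack: Boolean): Seq[String] = {
        val order = mutable.ArrayBuffer.empty[String]
        val stack = new mutable.Stack[(Int, String)]
        val queue = new mutable.Queue[(Int, String)]
        def add(n: (Int, String)): Unit = if (useStack) stack.push(n) else queue.enqueue(n)
        def next(): (Int, String) = if (useStack) stack.pop() else queue.dequeue()
        def more: Boolean = if (useStack) stack.nonEmpty else queue.nonEmpty

        Seq(0, 1).foreach(tree => add((tree, "root")))
        while (more) {
          val (tree, name) = next()
          order += s"t$tree-$name"
          if (name == "root") { add((tree, "L")); add((tree, "R")) }
        }
        order.toSeq
      }

      def main(args: Array[String]): Unit = {
        println(run(useStack = true).mkString(", "))
        // t1-root, t1-R, t1-L, t0-root, t0-R, t0-L  (one tree finishes before the next)
        println(run(useStack = false).mkString(", "))
        // t0-root, t1-root, t0-L, t0-R, t1-L, t1-R  (all trees advance in lockstep)
      }
    }

The stack-based order is what lets the new topNodesForGroup map stay small: a group drawn
from the top of the stack tends to touch fewer trees, so fewer roots are captured in the
closure shipped to executors.
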
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index dcc2f305df75aa3e54ae0cdda3c765c7ed9f0ab2..79b19ea5ad2062d94d1f67300f716eff1aa89a1e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -26,7 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy,
+  Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.util.collection.OpenHashMap
@@ -239,12 +240,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val treeToNodeToIndexInfo = Map((0, Map(
       (topNode.id, new RandomForest.NodeIndexInfo(0, None))
     )))
-    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
-    RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
-      nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
+    val nodeStack = new mutable.Stack[(Int, LearningNode)]
+    RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
+      nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
 
-    // don't enqueue leaf nodes into node queue
+    // don't push leaf nodes onto the node stack
-    assert(nodeQueue.isEmpty)
+    assert(nodeStack.isEmpty)
 
     // set impurity and predict for topNode
     assert(topNode.stats !== null)
@@ -281,12 +282,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val treeToNodeToIndexInfo = Map((0, Map(
       (topNode.id, new RandomForest.NodeIndexInfo(0, None))
     )))
-    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
-    RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
-      nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
+    val nodeStack = new mutable.Stack[(Int, LearningNode)]
+    RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
+      nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
 
-    // don't enqueue a node into node queue if its impurity is 0.0
+    // don't push a node onto the node stack if its impurity is 0.0
-    assert(nodeQueue.isEmpty)
+    assert(nodeStack.isEmpty)
 
     // set impurity and predict for topNode
     assert(topNode.stats !== null)
@@ -393,16 +394,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
         val failString = s"Failed on test with:" +
           s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
           s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
-        val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+        val nodeStack = new mutable.Stack[(Int, LearningNode)]
         val topNodes: Array[LearningNode] = new Array[LearningNode](numTrees)
         Range(0, numTrees).foreach { treeIndex =>
           topNodes(treeIndex) = LearningNode.emptyNode(nodeIndex = 1)
-          nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
+          nodeStack.push((treeIndex, topNodes(treeIndex)))
         }
         val rng = new scala.util.Random(seed = seed)
         val (nodesForGroup: Map[Int, Array[LearningNode]],
         treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
-          RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+          RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
 
         assert(nodesForGroup.size === numTrees, failString)
         assert(nodesForGroup.values.forall(_.length == 1), failString) // 1 node per tree
@@ -546,7 +547,6 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
-
 }
 
 private object RandomForestSuite {
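
Reviewer note (illustration only; not part of the patch): selectNodesToSplit peeks at
nodeStack.top and pops a node only after confirming that its aggregate-statistics cost
still fits within maxMemoryUsage, so an oversized node simply stays on the stack for a
later group.  Below is a simplified, self-contained Scala sketch of that peek-then-pop
pattern; Node, costBytes, and selectGroup are hypothetical stand-ins for LearningNode,
aggregateSizeForNode(...) * 8L, and the real selection loop.

    import scala.collection.mutable

    object SelectGroupSketch {
      // Hypothetical stand-in for a node plus its aggregate-statistics size in bytes.
      final case class Node(id: Int, costBytes: Long)

      // Peek-then-pop: a node leaves the stack only once we know it fits the budget.
      def selectGroup(stack: mutable.Stack[Node], maxBytes: Long): Seq[Node] = {
        val group = mutable.ArrayBuffer.empty[Node]
        var used = 0L
        var topFits = true
        // The `used == 0` checks mirror the patched loop: even with a tiny
        // maxMemoryInMB, one node is always admitted so training makes progress.
        while (topFits && stack.nonEmpty && (used < maxBytes || used == 0)) {
          val node = stack.top
          if (used + node.costBytes <= maxBytes || used == 0) {
            stack.pop()
            group += node
            used += node.costBytes
          } else {
            topFits = false  // leave the top node for the next group
          }
        }
        group.toSeq
      }

      def main(args: Array[String]): Unit = {
        val stack = new mutable.Stack[Node]
        Seq(Node(3, 400L), Node(2, 400L), Node(1, 400L)).foreach(stack.push)
        println(selectGroup(stack, maxBytes = 800L))  // the first two nodes fit
        println(stack)  // Node(3,400) remains on the stack for the next group
      }
    }

Because the loop always inspects the top of the stack, the group it assembles comes from
the most recently pushed nodes, i.e. the frontier of the tree currently being finished,
which is the behavior the stack comment in RandomForest.scala describes.
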