diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 82e1ed85a0a147b206bccfaf52be2ab5567fbf6b..f7d969f4ca5dbe6a441549651547a0bcd1806c26 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -1089,7 +1089,8 @@ private[spark] object RandomForest extends Logging {
     var numNodesInGroup = 0
     // If maxMemoryInMB is set very small, we want to still try to split 1 node,
     // so we allow one iteration if memUsage == 0.
-    while (nodeStack.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) {
+    var groupDone = false
+    while (nodeStack.nonEmpty && !groupDone) {
       val (treeIndex, node) = nodeStack.top
       // Choose subset of features for node (if subsampling).
       val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
@@ -1107,9 +1108,11 @@ private[spark] object RandomForest extends Logging {
         mutableTreeToNodeToIndexInfo
           .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
           = new NodeIndexInfo(numNodesInGroup, featureSubset)
+        numNodesInGroup += 1
+        memUsage += nodeMemUsage
+      } else {
+        groupDone = true
       }
-      numNodesInGroup += 1
-      memUsage += nodeMemUsage
     }
     if (memUsage > maxMemoryUsage) {
       // If maxMemoryUsage is 0, we should still allow splitting 1 node.