diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 82e1ed85a0a147b206bccfaf52be2ab5567fbf6b..f7d969f4ca5dbe6a441549651547a0bcd1806c26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1089,7 +1089,8 @@ private[spark] object RandomForest extends Logging { var numNodesInGroup = 0 // If maxMemoryInMB is set very small, we want to still try to split 1 node, // so we allow one iteration if memUsage == 0. - while (nodeStack.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) { + var groupDone = false + while (nodeStack.nonEmpty && !groupDone) { val (treeIndex, node) = nodeStack.top // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { @@ -1107,9 +1108,11 @@ private[spark] object RandomForest extends Logging { mutableTreeToNodeToIndexInfo .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) = new NodeIndexInfo(numNodesInGroup, featureSubset) + numNodesInGroup += 1 + memUsage += nodeMemUsage + } else { + groupDone = true } - numNodesInGroup += 1 - memUsage += nodeMemUsage } if (memUsage > maxMemoryUsage) { // If maxMemoryUsage is 0, we should still allow splitting 1 node.