diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 71c8c42ce5eba23b3ee59d46fd52b8b4daa259e9..0b7ad92b3cf30885bab0df3481b8eee602e1124f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -51,7 +51,7 @@ import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} * findSplits() method during initialization, after which each continuous feature becomes * an ordered discretized feature with at most maxBins possible values. * - * The main loop in the algorithm operates on a queue of nodes (nodeQueue). These nodes + * The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes * lie at the periphery of the tree being trained. If multiple trees are being trained at once, * then this queue contains nodes from all of them. Each iteration works roughly as follows: * On the master node: @@ -161,31 +161,42 @@ private[spark] object RandomForest extends Logging { None } - // FIFO queue of nodes to train: (treeIndex, node) - val nodeQueue = new mutable.Queue[(Int, LearningNode)]() + /* + Stack of nodes to train: (treeIndex, node) + The reason this is a stack is that we train many trees at once, but we want to focus on + completing trees, rather than training all simultaneously. If we are splitting nodes from + 1 tree, then the new nodes to split will be put at the top of this stack, so we will continue + training the same tree in the next iteration. This focus allows us to send fewer trees to + workers on each iteration; see topNodesForGroup below. + */ + val nodeStack = new mutable.Stack[(Int, LearningNode)] val rng = new Random() rng.setSeed(seed) // Allocate and queue root nodes. val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1)) - Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) + Range(0, numTrees).foreach(treeIndex => nodeStack.push((treeIndex, topNodes(treeIndex)))) timer.stop("init") - while (nodeQueue.nonEmpty) { + while (nodeStack.nonEmpty) { // Collect some nodes to split, and choose features for each node (if subsampling). // Each group of nodes may come from one or multiple trees, and at multiple levels. val (nodesForGroup, treeToNodeToIndexInfo) = - RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) // Sanity check (should never occur): assert(nodesForGroup.nonEmpty, s"RandomForest selected empty nodesForGroup. Error for unknown reason.") + // Only send trees to worker if they contain nodes being split this iteration. + val topNodesForGroup: Map[Int, LearningNode] = + nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap + // Choose node splits, and enqueue new nodes as needed. timer.start("findBestSplits") - RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, - treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache) + RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup, + treeToNodeToIndexInfo, splits, nodeStack, timer, nodeIdCache) timer.stop("findBestSplits") } @@ -334,13 +345,14 @@ private[spark] object RandomForest extends Logging { * * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]] * @param metadata Learning and dataset metadata - * @param topNodes Root node for each tree. Used for matching instances with nodes. + * @param topNodesForGroup For each tree in group, tree index -> root node. + * Used for matching instances with nodes. * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, * where nodeIndexInfo stores the index in the group and the * feature subsets (if using feature subsets). * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). + * @param nodeStack Queue of nodes to split, with values (treeIndex, node). * Updated with new non-leaf nodes which are created. * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where * each value in the array is the data point's node Id @@ -351,11 +363,11 @@ private[spark] object RandomForest extends Logging { private[tree] def findBestSplits( input: RDD[BaggedPoint[TreePoint]], metadata: DecisionTreeMetadata, - topNodes: Array[LearningNode], + topNodesForGroup: Map[Int, LearningNode], nodesForGroup: Map[Int, Array[LearningNode]], treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], splits: Array[Array[Split]], - nodeQueue: mutable.Queue[(Int, LearningNode)], + nodeStack: mutable.Stack[(Int, LearningNode)], timer: TimeTracker = new TimeTracker, nodeIdCache: Option[NodeIdCache] = None): Unit = { @@ -437,7 +449,8 @@ private[spark] object RandomForest extends Logging { agg: Array[DTStatsAggregator], baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = { treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) + val nodeIndex = + topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) } agg @@ -593,10 +606,10 @@ private[spark] object RandomForest extends Logging { // enqueue left child and right child if they are not leaves if (!leftChildIsLeaf) { - nodeQueue.enqueue((treeIndex, node.leftChild.get)) + nodeStack.push((treeIndex, node.leftChild.get)) } if (!rightChildIsLeaf) { - nodeQueue.enqueue((treeIndex, node.rightChild.get)) + nodeStack.push((treeIndex, node.rightChild.get)) } logDebug("leftChildIndex = " + node.leftChild.get.id + @@ -1029,7 +1042,7 @@ private[spark] object RandomForest extends Logging { * will be needed; this allows an adaptive number of nodes since different nodes may require * different amounts of memory (if featureSubsetStrategy is not "all"). * - * @param nodeQueue Queue of nodes to split. + * @param nodeStack Queue of nodes to split. * @param maxMemoryUsage Bound on size of aggregate statistics. * @return (nodesForGroup, treeToNodeToIndexInfo). * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree. @@ -1041,7 +1054,7 @@ private[spark] object RandomForest extends Logging { * The feature indices are None if not subsampling features. */ private[tree] def selectNodesToSplit( - nodeQueue: mutable.Queue[(Int, LearningNode)], + nodeStack: mutable.Stack[(Int, LearningNode)], maxMemoryUsage: Long, metadata: DecisionTreeMetadata, rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = { @@ -1054,8 +1067,8 @@ private[spark] object RandomForest extends Logging { var numNodesInGroup = 0 // If maxMemoryInMB is set very small, we want to still try to split 1 node, // so we allow one iteration if memUsage == 0. - while (nodeQueue.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) { - val (treeIndex, node) = nodeQueue.head + while (nodeStack.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) { + val (treeIndex, node) = nodeStack.top // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { Some(SamplingUtils.reservoirSampleAndCount(Range(0, @@ -1066,7 +1079,7 @@ private[spark] object RandomForest extends Logging { // Check if enough memory remains to add this node to the group. val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) { - nodeQueue.dequeue() + nodeStack.pop() mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) += node mutableTreeToNodeToIndexInfo @@ -1109,5 +1122,4 @@ private[spark] object RandomForest extends Logging { 3 * totalBins } } - } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index dcc2f305df75aa3e54ae0cdda3c765c7ed9f0ab2..79b19ea5ad2062d94d1f67300f716eff1aa89a1e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -26,7 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.tree._ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper} -import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, + Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.collection.OpenHashMap @@ -239,12 +240,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val treeToNodeToIndexInfo = Map((0, Map( (topNode.id, new RandomForest.NodeIndexInfo(0, None)) ))) - val nodeQueue = new mutable.Queue[(Int, LearningNode)]() - RandomForest.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue) + val nodeStack = new mutable.Stack[(Int, LearningNode)] + RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) // don't enqueue leaf nodes into node queue - assert(nodeQueue.isEmpty) + assert(nodeStack.isEmpty) // set impurity and predict for topNode assert(topNode.stats !== null) @@ -281,12 +282,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val treeToNodeToIndexInfo = Map((0, Map( (topNode.id, new RandomForest.NodeIndexInfo(0, None)) ))) - val nodeQueue = new mutable.Queue[(Int, LearningNode)]() - RandomForest.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue) + val nodeStack = new mutable.Stack[(Int, LearningNode)] + RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) // don't enqueue a node into node queue if its impurity is 0.0 - assert(nodeQueue.isEmpty) + assert(nodeStack.isEmpty) // set impurity and predict for topNode assert(topNode.stats !== null) @@ -393,16 +394,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val failString = s"Failed on test with:" + s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," + s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed" - val nodeQueue = new mutable.Queue[(Int, LearningNode)]() + val nodeStack = new mutable.Stack[(Int, LearningNode)] val topNodes: Array[LearningNode] = new Array[LearningNode](numTrees) Range(0, numTrees).foreach { treeIndex => topNodes(treeIndex) = LearningNode.emptyNode(nodeIndex = 1) - nodeQueue.enqueue((treeIndex, topNodes(treeIndex))) + nodeStack.push((treeIndex, topNodes(treeIndex))) } val rng = new scala.util.Random(seed = seed) val (nodesForGroup: Map[Int, Array[LearningNode]], treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = - RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) assert(nodesForGroup.size === numTrees, failString) assert(nodesForGroup.values.forall(_.length == 1), failString) // 1 node per tree @@ -546,7 +547,6 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } - } private object RandomForestSuite {