Commit 947b8c6e authored by Joseph K. Bradley

[SPARK-16719][ML] Random Forests should communicate fewer trees on each iteration

## What changes were proposed in this pull request?

RandomForest currently sends the entire forest to each worker on each iteration. This happens because (a) the node queue is FIFO and (b) the training closure references the entire array of tree roots (topNodes). Because of (a), each iteration handles splits from many different trees, especially early in learning; because of (b), all trees are shipped to the workers regardless of which ones are actually needed.
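
To see why a FIFO queue spreads each group of splits across many trees while a stack keeps it concentrated, here is a minimal, self-contained sketch in plain Scala (not Spark code): `Node`, `groupSize`, and the two-children expansion are hypothetical stand-ins for `LearningNode`, the memory-bounded group selection, and the split step.

```scala
import scala.collection.mutable

object FifoVsLifoSketch {
  final case class Node(treeIndex: Int, depth: Int)

  def run(useStack: Boolean, numTrees: Int = 3, maxDepth: Int = 2, groupSize: Int = 2): Unit = {
    val queue = new mutable.Queue[Node]()
    val stack = new mutable.Stack[Node]()
    // Seed with one root node per tree, as RandomForest does.
    (0 until numTrees).foreach { t => queue.enqueue(Node(t, 0)); stack.push(Node(t, 0)) }

    def nonEmpty: Boolean = if (useStack) stack.nonEmpty else queue.nonEmpty
    def next(): Node = if (useStack) stack.pop() else queue.dequeue()
    def put(n: Node): Unit = if (useStack) stack.push(n) else queue.enqueue(n)

    while (nonEmpty) {
      // Take a small group of nodes, as selectNodesToSplit does under its memory budget.
      val group = mutable.ArrayBuffer.empty[Node]
      while (nonEmpty && group.size < groupSize) group += next()
      println(s"group touches trees: ${group.map(_.treeIndex).distinct.sorted.mkString(", ")}")
      // Splitting a node yields two children in the same tree (until maxDepth).
      group.filter(_.depth < maxDepth).foreach { n =>
        put(Node(n.treeIndex, n.depth + 1))
        put(Node(n.treeIndex, n.depth + 1))
      }
    }
  }

  def main(args: Array[String]): Unit = {
    println("FIFO (old nodeQueue):")
    run(useStack = false)
    println("LIFO (new nodeStack):")
    run(useStack = true)
  }
}
```

Running this, the FIFO variant keeps mixing nodes from several trees in every group, while the LIFO variant quickly settles into groups drawn from a single tree.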

This PR:
(a) Changes the RF node queue to be FILO (a stack), so that RFs tend to focus on completing one or a few trees before working on others (the sketch above illustrates the scheduling difference).
(b) Changes topNodes to pass only the trees required on that iteration (see the sketch after this list).
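
For change (b), the point is that the closure shipped to executors should capture only the root nodes of trees that have nodes in the current group, not the full `topNodes` array. Below is a minimal, self-contained sketch of that idea; `LearningNode` here is a hypothetical stand-in for the Spark class and the group contents are made up, but the `topNodesForGroup` construction mirrors the line added in the patch.

```scala
object TopNodesForGroupSketch {
  // Hypothetical stand-in for org.apache.spark.ml.tree.impl.LearningNode.
  final case class LearningNode(id: Int)

  def main(args: Array[String]): Unit = {
    val numTrees = 10
    // Root node of every tree in the forest.
    val topNodes: Array[LearningNode] = Array.tabulate(numTrees)(i => LearningNode(id = i))
    // Suppose this iteration only selected nodes from trees 2 and 7 (made-up group).
    val nodesForGroup: Map[Int, Seq[LearningNode]] =
      Map(2 -> Seq(LearningNode(200)), 7 -> Seq(LearningNode(700)))

    // Before the patch: a closure over `topNodes` captures (and ships) all 10 roots.
    // After the patch: build a map holding only the roots needed for this group,
    // and reference that map from the closure instead.
    val topNodesForGroup: Map[Int, LearningNode] =
      nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap

    println(s"trees shipped this iteration: ${topNodesForGroup.keys.toSeq.sorted.mkString(", ")}")
  }
}
```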

## How was this patch tested?

Unit tests:
* Existing tests for correctness of tree learning
* Manually modifying code and running tests to verify that a small number of trees are communicated on each iteration
  * This last item is hard to test via unit tests given the current APIs.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #14359 from jkbradley/rfs-fewer-trees.
parent a4aeb767
@@ -51,7 +51,7 @@ import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
* findSplits() method during initialization, after which each continuous feature becomes
* an ordered discretized feature with at most maxBins possible values.
*
* The main loop in the algorithm operates on a queue of nodes (nodeQueue). These nodes
* The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes
* lie at the periphery of the tree being trained. If multiple trees are being trained at once,
* then this queue contains nodes from all of them. Each iteration works roughly as follows:
* On the master node:
@@ -161,31 +161,42 @@ private[spark] object RandomForest extends Logging {
None
}
// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
/*
Stack of nodes to train: (treeIndex, node)
The reason this is a stack is that we train many trees at once, but we want to focus on
completing trees, rather than training all simultaneously. If we are splitting nodes from
1 tree, then the new nodes to split will be put at the top of this stack, so we will continue
training the same tree in the next iteration. This focus allows us to send fewer trees to
workers on each iteration; see topNodesForGroup below.
*/
val nodeStack = new mutable.Stack[(Int, LearningNode)]
val rng = new Random()
rng.setSeed(seed)
// Allocate and queue root nodes.
val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
Range(0, numTrees).foreach(treeIndex => nodeStack.push((treeIndex, topNodes(treeIndex))))
timer.stop("init")
while (nodeQueue.nonEmpty) {
while (nodeStack.nonEmpty) {
// Collect some nodes to split, and choose features for each node (if subsampling).
// Each group of nodes may come from one or multiple trees, and at multiple levels.
val (nodesForGroup, treeToNodeToIndexInfo) =
RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
// Sanity check (should never occur):
assert(nodesForGroup.nonEmpty,
s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
// Only send trees to worker if they contain nodes being split this iteration.
val topNodesForGroup: Map[Int, LearningNode] =
nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup,
treeToNodeToIndexInfo, splits, nodeStack, timer, nodeIdCache)
timer.stop("findBestSplits")
}
@@ -334,13 +345,14 @@ private[spark] object RandomForest extends Logging {
*
* @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]]
* @param metadata Learning and dataset metadata
* @param topNodes Root node for each tree. Used for matching instances with nodes.
* @param topNodesForGroup For each tree in group, tree index -> root node.
* Used for matching instances with nodes.
* @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
* @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
* where nodeIndexInfo stores the index in the group and the
* feature subsets (if using feature subsets).
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
* @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
* @param nodeStack Queue of nodes to split, with values (treeIndex, node).
* Updated with new non-leaf nodes which are created.
* @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
* each value in the array is the data point's node Id
@@ -351,11 +363,11 @@ private[spark] object RandomForest extends Logging {
private[tree] def findBestSplits(
input: RDD[BaggedPoint[TreePoint]],
metadata: DecisionTreeMetadata,
topNodes: Array[LearningNode],
topNodesForGroup: Map[Int, LearningNode],
nodesForGroup: Map[Int, Array[LearningNode]],
treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
splits: Array[Array[Split]],
nodeQueue: mutable.Queue[(Int, LearningNode)],
nodeStack: mutable.Stack[(Int, LearningNode)],
timer: TimeTracker = new TimeTracker,
nodeIdCache: Option[NodeIdCache] = None): Unit = {
@@ -437,7 +449,8 @@ private[spark] object RandomForest extends Logging {
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
val nodeIndex =
topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}
agg
@@ -593,10 +606,10 @@ private[spark] object RandomForest extends Logging {
// enqueue left child and right child if they are not leaves
if (!leftChildIsLeaf) {
nodeQueue.enqueue((treeIndex, node.leftChild.get))
nodeStack.push((treeIndex, node.leftChild.get))
}
if (!rightChildIsLeaf) {
nodeQueue.enqueue((treeIndex, node.rightChild.get))
nodeStack.push((treeIndex, node.rightChild.get))
}
logDebug("leftChildIndex = " + node.leftChild.get.id +
@@ -1029,7 +1042,7 @@ private[spark] object RandomForest extends Logging {
* will be needed; this allows an adaptive number of nodes since different nodes may require
* different amounts of memory (if featureSubsetStrategy is not "all").
*
* @param nodeQueue Queue of nodes to split.
* @param nodeStack Queue of nodes to split.
* @param maxMemoryUsage Bound on size of aggregate statistics.
* @return (nodesForGroup, treeToNodeToIndexInfo).
* nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
@@ -1041,7 +1054,7 @@ private[spark] object RandomForest extends Logging {
* The feature indices are None if not subsampling features.
*/
private[tree] def selectNodesToSplit(
nodeQueue: mutable.Queue[(Int, LearningNode)],
nodeStack: mutable.Stack[(Int, LearningNode)],
maxMemoryUsage: Long,
metadata: DecisionTreeMetadata,
rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
@@ -1054,8 +1067,8 @@ private[spark] object RandomForest extends Logging {
var numNodesInGroup = 0
// If maxMemoryInMB is set very small, we want to still try to split 1 node,
// so we allow one iteration if memUsage == 0.
while (nodeQueue.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) {
val (treeIndex, node) = nodeQueue.head
while (nodeStack.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) {
val (treeIndex, node) = nodeStack.top
// Choose subset of features for node (if subsampling).
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
Some(SamplingUtils.reservoirSampleAndCount(Range(0,
@@ -1066,7 +1079,7 @@ private[spark] object RandomForest extends Logging {
// Check if enough memory remains to add this node to the group.
val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) {
nodeQueue.dequeue()
nodeStack.pop()
mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
node
mutableTreeToNodeToIndexInfo
@@ -1109,5 +1122,4 @@ private[spark] object RandomForest extends Logging {
3 * totalBins
}
}
}
@@ -26,7 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy,
Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.collection.OpenHashMap
@@ -239,12 +240,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val treeToNodeToIndexInfo = Map((0, Map(
(topNode.id, new RandomForest.NodeIndexInfo(0, None))
)))
val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
val nodeStack = new mutable.Stack[(Int, LearningNode)]
RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
// don't enqueue leaf nodes into node queue
assert(nodeQueue.isEmpty)
assert(nodeStack.isEmpty)
// set impurity and predict for topNode
assert(topNode.stats !== null)
@@ -281,12 +282,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val treeToNodeToIndexInfo = Map((0, Map(
(topNode.id, new RandomForest.NodeIndexInfo(0, None))
)))
val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
val nodeStack = new mutable.Stack[(Int, LearningNode)]
RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
// don't enqueue a node into node queue if its impurity is 0.0
assert(nodeQueue.isEmpty)
assert(nodeStack.isEmpty)
// set impurity and predict for topNode
assert(topNode.stats !== null)
@@ -393,16 +394,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val failString = s"Failed on test with:" +
s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
val nodeStack = new mutable.Stack[(Int, LearningNode)]
val topNodes: Array[LearningNode] = new Array[LearningNode](numTrees)
Range(0, numTrees).foreach { treeIndex =>
topNodes(treeIndex) = LearningNode.emptyNode(nodeIndex = 1)
nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
nodeStack.push((treeIndex, topNodes(treeIndex)))
}
val rng = new scala.util.Random(seed = seed)
val (nodesForGroup: Map[Int, Array[LearningNode]],
treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
assert(nodesForGroup.size === numTrees, failString)
assert(nodesForGroup.values.forall(_.length == 1), failString) // 1 node per tree
@@ -546,7 +547,6 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
}
private object RandomForestSuite {