diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index cd24931293903a7d2f1fedeb5e67dd4b6c700dfc..d89682611e3f5c8b6dd32cde306eba14891f9964 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -279,6 +279,43 @@ private[tree] class LearningNode(
     }
   }
 
+  /**
+   * Get the node index corresponding to this data point.
+   * This function mimics prediction, passing an example from the root node down to a leaf
+   * or unsplit node; that node's index is returned.
+   *
+   * @param binnedFeatures  Binned feature vector for data point.
+   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+   * @return  Leaf index if the data point reaches a leaf.
+   *          Otherwise, last node reachable in tree matching this example.
+   *          Note: This is the global node index, i.e., the index used in the tree.
+   *                This index is different from the index used during training a particular
+   *                group of nodes on one call to [[findBestSplits()]].
+   */
+  def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = {
+    if (this.isLeaf || this.split.isEmpty) {
+      this.id
+    } else {
+      val split = this.split.get
+      val featureIndex = split.featureIndex
+      val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
+      if (this.leftChild.isEmpty) {
+        // Not yet split. Return next layer of nodes to train
+        if (splitLeft) {
+          LearningNode.leftChildIndex(this.id)
+        } else {
+          LearningNode.rightChildIndex(this.id)
+        }
+      } else {
+        if (splitLeft) {
+          this.leftChild.get.predictImpl(binnedFeatures, splits)
+        } else {
+          this.rightChild.get.predictImpl(binnedFeatures, splits)
+        }
+      }
+    }
+  }
+
 }
 
 private[tree] object LearningNode {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index c494556085e952c9402d50f71501fa0a12a00f34..96d5652857e08aa2bc2458f9045074499ebef011 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -205,47 +205,6 @@ private[ml] object RandomForest extends Logging {
     }
   }
 
-  /**
-   * Get the node index corresponding to this data point.
-   * This function mimics prediction, passing an example from the root node down to a leaf
-   * or unsplit node; that node's index is returned.
-   *
-   * @param node  Node in tree from which to classify the given data point.
-   * @param binnedFeatures  Binned feature vector for data point.
-   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
-   * @return  Leaf index if the data point reaches a leaf.
-   *          Otherwise, last node reachable in tree matching this example.
-   *          Note: This is the global node index, i.e., the index used in the tree.
-   *                This index is different from the index used during training a particular
-   *                group of nodes on one call to [[findBestSplits()]].
-   */
-  private def predictNodeIndex(
-      node: LearningNode,
-      binnedFeatures: Array[Int],
-      splits: Array[Array[Split]]): Int = {
-    if (node.isLeaf || node.split.isEmpty) {
-      node.id
-    } else {
-      val split = node.split.get
-      val featureIndex = split.featureIndex
-      val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
-      if (node.leftChild.isEmpty) {
-        // Not yet split. Return index from next layer of nodes to train
-        if (splitLeft) {
-          LearningNode.leftChildIndex(node.id)
-        } else {
-          LearningNode.rightChildIndex(node.id)
-        }
-      } else {
-        if (splitLeft) {
-          predictNodeIndex(node.leftChild.get, binnedFeatures, splits)
-        } else {
-          predictNodeIndex(node.rightChild.get, binnedFeatures, splits)
-        }
-      }
-    }
-  }
-
   /**
    * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
    *
@@ -453,8 +412,7 @@ private[ml] object RandomForest extends Logging {
         agg: Array[DTStatsAggregator],
         baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
       treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
-        val nodeIndex =
-          predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits)
+        val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
         nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
       }
       agg