From cca2258685147be6c950c9f5c4e50eaa1e090714 Mon Sep 17 00:00:00 2001
From: Luvsandondov Lkhamsuren <lkhamsurenl@gmail.com>
Date: Sat, 17 Oct 2015 10:07:42 -0700
Subject: [PATCH] [SPARK-9963] [ML] RandomForest cleanup: replace
 predictNodeIndex with predictImpl

predictNodeIndex is moved to LearningNode and renamed predictImpl for consistency with Node.predictImpl

Author: Luvsandondov Lkhamsuren <lkhamsurenl@gmail.com>

Closes #8609 from lkhamsurenl/SPARK-9963.
---
 .../scala/org/apache/spark/ml/tree/Node.scala | 37 ++++++++++++++++
 .../spark/ml/tree/impl/RandomForest.scala     | 44 +------------------
 2 files changed, 38 insertions(+), 43 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index cd24931293..d89682611e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -279,6 +279,43 @@ private[tree] class LearningNode(
     }
   }
 
+  /**
+   * Get the node index corresponding to this data point.
+   * This function mimics prediction, passing an example from the root node down to a leaf
+   * or unsplit node; that node's index is returned.
+   *
+   * @param binnedFeatures  Binned feature vector for data point.
+   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+   * @return  Leaf index if the data point reaches a leaf.
+   *          Otherwise, last node reachable in tree matching this example.
+   *          Note: This is the global node index, i.e., the index used in the tree.
+   *                This index is different from the index used during training a particular
+   *                group of nodes on one call to [[findBestSplits()]].
+   */
+  def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = {
+    if (this.isLeaf || this.split.isEmpty) {
+      this.id
+    } else {
+      val split = this.split.get
+      val featureIndex = split.featureIndex
+      val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
+      if (this.leftChild.isEmpty) {
+        // Not yet split. Return next layer of nodes to train
+        if (splitLeft) {
+          LearningNode.leftChildIndex(this.id)
+        } else {
+          LearningNode.rightChildIndex(this.id)
+        }
+      } else {
+        if (splitLeft) {
+          this.leftChild.get.predictImpl(binnedFeatures, splits)
+        } else {
+          this.rightChild.get.predictImpl(binnedFeatures, splits)
+        }
+      }
+    }
+  }
+
 }
 
 private[tree] object LearningNode {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index c494556085..96d5652857 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -205,47 +205,6 @@ private[ml] object RandomForest extends Logging {
     }
   }
 
-  /**
-   * Get the node index corresponding to this data point.
-   * This function mimics prediction, passing an example from the root node down to a leaf
-   * or unsplit node; that node's index is returned.
-   *
-   * @param node  Node in tree from which to classify the given data point.
-   * @param binnedFeatures  Binned feature vector for data point.
-   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
-   * @return  Leaf index if the data point reaches a leaf.
-   *          Otherwise, last node reachable in tree matching this example.
-   *          Note: This is the global node index, i.e., the index used in the tree.
-   *                This index is different from the index used during training a particular
-   *                group of nodes on one call to [[findBestSplits()]].
-   */
-  private def predictNodeIndex(
-      node: LearningNode,
-      binnedFeatures: Array[Int],
-      splits: Array[Array[Split]]): Int = {
-    if (node.isLeaf || node.split.isEmpty) {
-      node.id
-    } else {
-      val split = node.split.get
-      val featureIndex = split.featureIndex
-      val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
-      if (node.leftChild.isEmpty) {
-        // Not yet split. Return index from next layer of nodes to train
-        if (splitLeft) {
-          LearningNode.leftChildIndex(node.id)
-        } else {
-          LearningNode.rightChildIndex(node.id)
-        }
-      } else {
-        if (splitLeft) {
-          predictNodeIndex(node.leftChild.get, binnedFeatures, splits)
-        } else {
-          predictNodeIndex(node.rightChild.get, binnedFeatures, splits)
-        }
-      }
-    }
-  }
-
   /**
    * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
    *
@@ -453,8 +412,7 @@ private[ml] object RandomForest extends Logging {
         agg: Array[DTStatsAggregator],
         baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
       treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
-        val nodeIndex =
-          predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits)
+        val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
         nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
       }
       agg
-- 
GitLab