From c7032290a3f0f5545aa4f0a9a144c62571344dc8 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph.kurata.bradley@gmail.com>
Date: Fri, 15 Aug 2014 14:50:10 -0700
Subject: [PATCH] [SPARK-3022] [SPARK-3041] [mllib] Call findBins once per
 level + unordered feature bug fix

DecisionTree improvements:
(1) TreePoint representation to avoid binning multiple times
(2) Bug fix: isSampleValid indexed bins incorrectly for unordered categorical features
(3) Timing for DecisionTree internals

Details:

(1) TreePoint representation to avoid binning multiple times

[https://issues.apache.org/jira/browse/SPARK-3022]

Added private[tree] TreePoint class for representing binned feature values.

The input RDD of LabeledPoint is converted to the TreePoint representation initially and then cached.  This avoids the previous problem of re-computing bins multiple times.

(2) Bug fix: isSampleValid indexed bins incorrectly for unordered categorical features

[https://issues.apache.org/jira/browse/SPARK-3041]

isSampleValid used to treat unordered categorical features incorrectly: It treated the bins as if indexed by featured values, rather than by subsets of values/categories.
* exhibited for unordered features (multi-class classification with categorical features of low arity)
* Fix: Index bins correctly for unordered categorical features.

(3) Timing for DecisionTree internals

Added tree/impl/TimeTracker.scala class which is private[tree] for now, for timing key parts of DT code.
Prints timing info via logDebug.

CC: mengxr manishamde chouqin  Very similar update, with one bug fix.  Many apologies for the conflicting update, but I hope that a few more optimizations I have on the way (which depend on this update) will prove valuable to you: SPARK-3042 and SPARK-3043

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1950 from jkbradley/dt-opt1 and squashes the following commits:

5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint
6b5651e [Joseph K. Bradley] Updates based on code review.  1 major change: persisting to memory + disk, not just memory.
2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
430d782 [Joseph K. Bradley] Added more debug info on binning error.  Added some docs.
d036089 [Joseph K. Bradley] Print timing info to logDebug.
e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private
8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up.  Removed debugging println calls from DecisionTree.  Made TreePoint extend Serialiable
a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree
3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging)
f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
a95bc22 [Joseph K. Bradley] timing for DecisionTree internals
---
 .../spark/mllib/tree/DecisionTree.scala       | 289 ++++++++----------
 .../mllib/tree/configuration/Strategy.scala   |  43 ++-
 .../spark/mllib/tree/impl/TimeTracker.scala   |  73 +++++
 .../spark/mllib/tree/impl/TreePoint.scala     | 201 ++++++++++++
 .../spark/mllib/tree/DecisionTreeSuite.scala  |  50 +--
 5 files changed, 449 insertions(+), 207 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index bb50f07be5..2a3107a13e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,22 +17,24 @@
 
 package org.apache.spark.mllib.tree
 
-import org.apache.spark.api.java.JavaRDD
-
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
+import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity}
+import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
+import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.random.XORShiftRandom
 
+
 /**
  * :: Experimental ::
  * A class which implements a decision tree learning algorithm for classification and regression.
@@ -53,16 +55,27 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
    */
   def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
 
-    // Cache input RDD for speedup during multiple passes.
-    val retaggedInput = input.retag(classOf[LabeledPoint]).cache()
+    val timer = new TimeTracker()
+
+    timer.start("total")
+
+    timer.start("init")
+
+    val retaggedInput = input.retag(classOf[LabeledPoint])
     logDebug("algo = " + strategy.algo)
 
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
+    timer.start("findSplitsBins")
     val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy)
     val numBins = bins(0).length
+    timer.stop("findSplitsBins")
     logDebug("numBins = " + numBins)
 
+    // Cache input RDD for speedup during multiple passes.
+    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
+      .persist(StorageLevel.MEMORY_AND_DISK)
+
     // depth of the decision tree
     val maxDepth = strategy.maxDepth
     // the max number of nodes possible given the depth of the tree
@@ -76,7 +89,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     // dummy value for top node (updated during first split calculation)
     val nodes = new Array[Node](maxNumNodes)
     // num features
-    val numFeatures = retaggedInput.take(1)(0).features.size
+    val numFeatures = treeInput.take(1)(0).binnedFeatures.size
 
     // Calculate level for single group construction
 
@@ -96,6 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0)
     logDebug("max level for single group = " + maxLevelForSingleGroup)
 
+    timer.stop("init")
+
     /*
      * The main idea here is to perform level-wise training of the decision tree nodes thus
      * reducing the passes over the data from l to log2(l) where l is the total number of nodes.
@@ -113,15 +128,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       logDebug("#####################################")
 
       // Find best split for all nodes at a level.
-      val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities,
-        strategy, level, filters, splits, bins, maxLevelForSingleGroup)
+      timer.start("findBestSplits")
+      val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
+        strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer)
+      timer.stop("findBestSplits")
 
       for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
+        timer.start("extractNodeInfo")
         // Extract info for nodes at the current level.
         extractNodeInfo(nodeSplitStats, level, index, nodes)
+        timer.stop("extractNodeInfo")
+        timer.start("extractInfoForLowerLevels")
         // Extract info for nodes at the next lower level.
         extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
           filters)
+        timer.stop("extractInfoForLowerLevels")
         logDebug("final best split = " + nodeSplitStats._1)
       }
       require(math.pow(2, level) == splitsStatsForLevel.length)
@@ -144,6 +165,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     // Build the full tree using the node info calculated in the level-wise best split calculations.
     topNode.build(nodes)
 
+    timer.stop("total")
+
+    logInfo("Internal timing for DecisionTree:")
+    logInfo(s"$timer")
+
     new DecisionTreeModel(topNode, strategy.algo)
   }
 
@@ -406,7 +432,7 @@ object DecisionTree extends Serializable with Logging {
    * Returns an array of optimal splits for all nodes at a given level. Splits the task into
    * multiple groups if the level-wise training task could lead to memory overflow.
    *
-   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
    * @param parentImpurities Impurities for all parent nodes for the current level
    * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
    *                 parameters for constructing the DecisionTree
@@ -415,44 +441,45 @@ object DecisionTree extends Serializable with Logging {
    * @param splits possible splits for all features
    * @param bins possible bins for all features
    * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
-   * @return array of splits with best splits for all nodes at a given level.
+   * @return array (over nodes) of splits with best split for each node at a given level.
    */
   protected[tree] def findBestSplits(
-      input: RDD[LabeledPoint],
+      input: RDD[TreePoint],
       parentImpurities: Array[Double],
       strategy: Strategy,
       level: Int,
       filters: Array[List[Filter]],
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
-      maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = {
+      maxLevelForSingleGroup: Int,
+      timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
     // split into groups to avoid memory overflow during aggregation
     if (level > maxLevelForSingleGroup) {
       // When information for all nodes at a given level cannot be stored in memory,
       // the nodes are divided into multiple groups at each level with the number of groups
       // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
       // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
-      val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt
+      val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt
       logDebug("numGroups = " + numGroups)
       var bestSplits = new Array[(Split, InformationGainStats)](0)
       // Iterate over each group of nodes at a level.
       var groupIndex = 0
       while (groupIndex < numGroups) {
         val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
-          filters, splits, bins, numGroups, groupIndex)
+          filters, splits, bins, timer, numGroups, groupIndex)
         bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
         groupIndex += 1
       }
       bestSplits
     } else {
-      findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins)
+      findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer)
     }
   }
 
     /**
    * Returns an array of optimal splits for a group of nodes at a given level
    *
-   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
    * @param parentImpurities Impurities for all parent nodes for the current level
    * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
    *                 parameters for constructing the DecisionTree
@@ -465,13 +492,14 @@ object DecisionTree extends Serializable with Logging {
    * @return array of splits with best splits for all nodes at a given level.
    */
   private def findBestSplitsPerGroup(
-      input: RDD[LabeledPoint],
+      input: RDD[TreePoint],
       parentImpurities: Array[Double],
       strategy: Strategy,
       level: Int,
       filters: Array[List[Filter]],
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
+      timer: TimeTracker,
       numGroups: Int = 1,
       groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
 
@@ -507,7 +535,7 @@ object DecisionTree extends Serializable with Logging {
     logDebug("numNodes = " + numNodes)
 
     // Find the number of features by looking at the first sample.
-    val numFeatures = input.first().features.size
+    val numFeatures = input.first().binnedFeatures.size
     logDebug("numFeatures = " + numFeatures)
 
     // numBins:  Number of bins = 1 + number of possible splits
@@ -542,33 +570,43 @@ object DecisionTree extends Serializable with Logging {
      * Find whether the sample is valid input for the current node, i.e., whether it passes through
      * all the filters for the current node.
      */
-    def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
+    def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): Boolean = {
       // leaf
       if ((level > 0) && (parentFilters.length == 0)) {
         return false
       }
 
       // Apply each filter and check sample validity. Return false when invalid condition found.
-      for (filter <- parentFilters) {
-        val features = labeledPoint.features
+      parentFilters.foreach { filter =>
         val featureIndex = filter.split.feature
-        val threshold = filter.split.threshold
         val comparison = filter.comparison
-        val categories = filter.split.categories
         val isFeatureContinuous = filter.split.featureType == Continuous
-        val feature =  features(featureIndex)
         if (isFeatureContinuous) {
+          val binId = treePoint.binnedFeatures(featureIndex)
+          val bin = bins(featureIndex)(binId)
+          val featureValue = bin.highSplit.threshold
+          val threshold = filter.split.threshold
           comparison match {
-            case -1 => if (feature > threshold) return false
-            case 1 => if (feature <= threshold) return false
+            case -1 => if (featureValue > threshold) return false
+            case 1 => if (featureValue <= threshold) return false
           }
         } else {
-          val containsFeature = categories.contains(feature)
+          val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+          val isSpaceSufficientForAllCategoricalSplits =
+            numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1
+          val isUnorderedFeature =
+            isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
+          val featureValue = if (isUnorderedFeature) {
+            treePoint.binnedFeatures(featureIndex)
+          } else {
+            val binId = treePoint.binnedFeatures(featureIndex)
+            bins(featureIndex)(binId).category
+          }
+          val containsFeature = filter.split.categories.contains(featureValue)
           comparison match {
             case -1 => if (!containsFeature) return false
             case 1 => if (containsFeature) return false
           }
-
         }
       }
 
@@ -576,103 +614,6 @@ object DecisionTree extends Serializable with Logging {
       true
     }
 
-    /**
-     * Find bin for one (labeledPoint, feature).
-     */
-    def findBin(
-        featureIndex: Int,
-        labeledPoint: LabeledPoint,
-        isFeatureContinuous: Boolean,
-        isSpaceSufficientForAllCategoricalSplits: Boolean): Int = {
-      val binForFeatures = bins(featureIndex)
-      val feature = labeledPoint.features(featureIndex)
-
-      /**
-       * Binary search helper method for continuous feature.
-       */
-      def binarySearchForBins(): Int = {
-        var left = 0
-        var right = binForFeatures.length - 1
-        while (left <= right) {
-          val mid = left + (right - left) / 2
-          val bin = binForFeatures(mid)
-          val lowThreshold = bin.lowSplit.threshold
-          val highThreshold = bin.highSplit.threshold
-          if ((lowThreshold < feature) && (highThreshold >= feature)) {
-            return mid
-          }
-          else if (lowThreshold >= feature) {
-            right = mid - 1
-          }
-          else {
-            left = mid + 1
-          }
-        }
-        -1
-      }
-
-      /**
-       * Sequential search helper method to find bin for categorical feature in multiclass
-       * classification. The category is returned since each category can belong to multiple
-       * splits. The actual left/right child allocation per split is performed in the
-       * sequential phase of the bin aggregate operation.
-       */
-      def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = {
-        labeledPoint.features(featureIndex).toInt
-      }
-
-      /**
-       * Sequential search helper method to find bin for categorical feature
-       * (for classification and regression).
-       */
-      def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
-        val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-        val featureValue = labeledPoint.features(featureIndex)
-        var binIndex = 0
-        while (binIndex < featureCategories) {
-          val bin = bins(featureIndex)(binIndex)
-          val categories = bin.highSplit.categories
-          if (categories.contains(featureValue)) {
-            return binIndex
-          }
-          binIndex += 1
-        }
-        if (featureValue < 0 || featureValue >= featureCategories) {
-          throw new IllegalArgumentException(
-            s"DecisionTree given invalid data:" +
-            s" Feature $featureIndex is categorical with values in" +
-            s" {0,...,${featureCategories - 1}," +
-            s" but a data point gives it value $featureValue.\n" +
-            "  Bad data point: " + labeledPoint.toString)
-        }
-        -1
-      }
-
-      if (isFeatureContinuous) {
-        // Perform binary search for finding bin for continuous features.
-        val binIndex = binarySearchForBins()
-        if (binIndex == -1) {
-          throw new UnknownError("no bin was found for continuous variable.")
-        }
-        binIndex
-      } else {
-        // Perform sequential search to find bin for categorical features.
-        val binIndex = {
-          val isUnorderedFeature =
-            isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
-          if (isUnorderedFeature) {
-            sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
-          } else {
-            sequentialBinSearchForOrderedCategoricalFeature()
-          }
-        }
-        if (binIndex == -1) {
-          throw new UnknownError("no bin was found for categorical variable.")
-        }
-        binIndex
-      }
-    }
-
     /**
      * Finds bins for all nodes (and all features) at a given level.
      * For l nodes, k features the storage is as follows:
@@ -689,17 +630,17 @@ object DecisionTree extends Serializable with Logging {
      *            bin index for this labeledPoint
      *            (or InvalidBinIndex if labeledPoint is not handled by this node)
      */
-    def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
+    def findBinsForLevel(treePoint: TreePoint): Array[Double] = {
       // Calculate bin index and label per feature per node.
       val arr = new Array[Double](1 + (numFeatures * numNodes))
       // First element of the array is the label of the instance.
-      arr(0) = labeledPoint.label
+      arr(0) = treePoint.label
       // Iterate over nodes.
       var nodeIndex = 0
       while (nodeIndex < numNodes) {
         val parentFilters = findParentFilters(nodeIndex)
         // Find out whether the sample qualifies for the particular node.
-        val sampleValid = isSampleValid(parentFilters, labeledPoint)
+        val sampleValid = isSampleValid(parentFilters, treePoint)
         val shift = 1 + numFeatures * nodeIndex
         if (!sampleValid) {
           // Mark one bin as -1 is sufficient.
@@ -707,19 +648,7 @@ object DecisionTree extends Serializable with Logging {
         } else {
           var featureIndex = 0
           while (featureIndex < numFeatures) {
-            val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex)
-            val isFeatureContinuous = featureInfo.isEmpty
-            if (isFeatureContinuous) {
-              arr(shift + featureIndex)
-                = findBin(featureIndex, labeledPoint, isFeatureContinuous, false)
-            } else {
-              val featureCategories = featureInfo.get
-              val isSpaceSufficientForAllCategoricalSplits
-                = numBins > math.pow(2, featureCategories.toInt - 1) - 1
-              arr(shift + featureIndex)
-                = findBin(featureIndex, labeledPoint, isFeatureContinuous,
-                isSpaceSufficientForAllCategoricalSplits)
-            }
+            arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex)
             featureIndex += 1
           }
         }
@@ -728,7 +657,8 @@ object DecisionTree extends Serializable with Logging {
       arr
     }
 
-     // Find feature bins for all nodes at a level.
+    // Find feature bins for all nodes at a level.
+    timer.start("aggregation")
     val binMappedRDD = input.map(x => findBinsForLevel(x))
 
     /**
@@ -830,6 +760,8 @@ object DecisionTree extends Serializable with Logging {
       }
     }
 
+    val rightChildShift = numClasses * numBins * numFeatures * numNodes
+
     /**
      * Helper for binSeqOp.
      *
@@ -853,7 +785,6 @@ object DecisionTree extends Serializable with Logging {
         val validSignalIndex = 1 + numFeatures * nodeIndex
         val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
         if (isSampleValidForNode) {
-          val rightChildShift = numClasses * numBins * numFeatures * numNodes
           // actual class label
           val label = arr(0)
           // Iterate over all features.
@@ -912,7 +843,7 @@ object DecisionTree extends Serializable with Logging {
             val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3
             agg(aggIndex) = agg(aggIndex) + 1
             agg(aggIndex + 1) = agg(aggIndex + 1) + label
-            agg(aggIndex + 2) = agg(aggIndex + 2) + label*label
+            agg(aggIndex + 2) = agg(aggIndex + 2) + label * label
             featureIndex += 1
           }
         }
@@ -977,6 +908,7 @@ object DecisionTree extends Serializable with Logging {
     val binAggregates = {
       binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
     }
+    timer.stop("aggregation")
     logDebug("binAggregates.length = " + binAggregates.length)
 
     /**
@@ -1031,10 +963,17 @@ object DecisionTree extends Serializable with Logging {
           def indexOfLargestArrayElement(array: Array[Double]): Int = {
             val result = array.foldLeft(-1, Double.MinValue, 0) {
               case ((maxIndex, maxValue, currentIndex), currentValue) =>
-                if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1)
-                else (maxIndex, maxValue, currentIndex + 1)
+                if (currentValue > maxValue) {
+                  (currentIndex, currentValue, currentIndex + 1)
+                } else {
+                  (maxIndex, maxValue, currentIndex + 1)
+                }
+            }
+            if (result._1 < 0) {
+              throw new RuntimeException("DecisionTree internal error:" +
+                " calculateGainForSplit failed in indexOfLargestArrayElement")
             }
-            if (result._1 < 0) 0 else result._1
+            result._1
           }
 
           val predict = indexOfLargestArrayElement(leftRightCounts)
@@ -1057,6 +996,7 @@ object DecisionTree extends Serializable with Logging {
           val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
 
           new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+
         case Regression =>
           val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
           val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
@@ -1280,15 +1220,41 @@ object DecisionTree extends Serializable with Logging {
         nodeImpurity: Double): Array[Array[InformationGainStats]] = {
       val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
 
-      for (featureIndex <- 0 until numFeatures) {
-        for (splitIndex <- 0 until numBins - 1) {
+      var featureIndex = 0
+      while (featureIndex < numFeatures) {
+        val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
+        var splitIndex = 0
+        while (splitIndex < numSplitsForFeature) {
           gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
             splitIndex, rightNodeAgg, nodeImpurity)
+          splitIndex += 1
         }
+        featureIndex += 1
       }
       gains
     }
 
+    /**
+     * Get the number of splits for a feature.
+     */
+    def getNumSplitsForFeature(featureIndex: Int): Int = {
+      val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+      if (isFeatureContinuous) {
+        numBins - 1
+      } else {
+        // Categorical feature
+        val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+        val isSpaceSufficientForAllCategoricalSplits =
+          numBins > math.pow(2, featureCategories.toInt - 1) - 1
+        if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
+          math.pow(2.0, featureCategories - 1).toInt - 1
+        } else {
+          // Ordered features
+          featureCategories
+        }
+      }
+    }
+
     /**
      * Find the best split for a node.
      * @param binData Bin data slice for this node, given by getBinDataForNode.
@@ -1307,7 +1273,7 @@ object DecisionTree extends Serializable with Logging {
       // Calculate gains for all splits.
       val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
 
-      val (bestFeatureIndex,bestSplitIndex, gainStats) = {
+      val (bestFeatureIndex, bestSplitIndex, gainStats) = {
         // Initialize with infeasible values.
         var bestFeatureIndex = Int.MinValue
         var bestSplitIndex = Int.MinValue
@@ -1317,22 +1283,8 @@ object DecisionTree extends Serializable with Logging {
         while (featureIndex < numFeatures) {
           // Iterate over all splits.
           var splitIndex = 0
-          val maxSplitIndex: Double = {
-            val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-            if (isFeatureContinuous) {
-              numBins - 1
-            } else { // Categorical feature
-              val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-              val isSpaceSufficientForAllCategoricalSplits
-                = numBins > math.pow(2, featureCategories.toInt - 1) - 1
-              if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
-                math.pow(2.0, featureCategories - 1).toInt - 1
-              } else { // Binary classification
-                featureCategories
-              }
-            }
-          }
-          while (splitIndex < maxSplitIndex) {
+          val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
+          while (splitIndex < numSplitsForFeature) {
             val gainStats = gains(featureIndex)(splitIndex)
             if (gainStats.gain > bestGainStats.gain) {
               bestGainStats = gainStats
@@ -1383,6 +1335,7 @@ object DecisionTree extends Serializable with Logging {
     }
 
     // Calculate best splits for all nodes at a given level
+    timer.start("chooseSplits")
     val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
     // Iterating over all nodes at this level
     var node = 0
@@ -1395,6 +1348,8 @@ object DecisionTree extends Serializable with Logging {
       bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
       node += 1
     }
+    timer.stop("chooseSplits")
+
     bestSplits
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index f31a503608..cfc8192a85 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -27,22 +27,30 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 /**
  * :: Experimental ::
  * Stores all the configuration options for tree construction
- * @param algo classification or regression
- * @param impurity criterion used for information gain calculation
+ * @param algo  Learning goal.  Supported:
+ *              [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+ *              [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * @param impurity Criterion used for information gain calculation.
+ *                 Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]],
+ *                  [[org.apache.spark.mllib.tree.impurity.Entropy]].
+ *                 Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]].
  * @param maxDepth Maximum depth of the tree.
  *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * @param numClassesForClassification number of classes for classification. Default value is 2
- *                                    leads to binary classification
- * @param maxBins maximum number of bins used for splitting features
- * @param quantileCalculationStrategy algorithm for calculating quantiles
+ * @param numClassesForClassification Number of classes for classification.
+ *                                    (Ignored for regression.)
+ *                                    Default value is 2 (binary classification).
+ * @param maxBins Maximum number of bins used for discretizing continuous features and
+ *                for choosing how to split on features at each node.
+ *                More bins give higher granularity.
+ * @param quantileCalculationStrategy Algorithm for calculating quantiles.  Supported:
+   *                             [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
  * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
  *                                number of discrete values they take. For example, an entry (n ->
  *                                k) implies the feature n is categorical with k categories 0,
  *                                1, 2, ... , k-1. It's important to note that features are
  *                                zero-indexed.
- * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is
+ * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
  *                      128 MB.
- *
  */
 @Experimental
 class Strategy (
@@ -64,20 +72,7 @@ class Strategy (
     = isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
 
   /**
-   * Java-friendly constructor.
-   *
-   * @param algo classification or regression
-   * @param impurity criterion used for information gain calculation
-   * @param maxDepth Maximum depth of the tree.
-   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   * @param numClassesForClassification number of classes for classification. Default value is 2
-   *                                    leads to binary classification
-   * @param maxBins maximum number of bins used for splitting features
-   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
-   *                                the number of discrete values they take. For example, an entry
-   *                                (n -> k) implies the feature n is categorical with k categories
-   *                                0, 1, 2, ... , k-1. It's important to note that features are
-   *                                zero-indexed.
+   * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
    */
   def this(
       algo: Algo,
@@ -90,6 +85,10 @@ class Strategy (
       categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
   }
 
+  /**
+   * Check validity of parameters.
+   * Throws exception if invalid.
+   */
   private[tree] def assertValid(): Unit = {
     algo match {
       case Classification =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
new file mode 100644
index 0000000000..d215d68c42
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import scala.collection.mutable.{HashMap => MutableHashMap}
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * Time tracker implementation which holds labeled timers.
+ */
+@Experimental
+private[tree] class TimeTracker extends Serializable {
+
+  private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
+
+  private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
+
+  /**
+   * Starts a new timer, or re-starts a stopped timer.
+   */
+  def start(timerLabel: String): Unit = {
+    val currentTime = System.nanoTime()
+    if (starts.contains(timerLabel)) {
+      throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" +
+        s" timerLabel = $timerLabel before that timer was stopped.")
+    }
+    starts(timerLabel) = currentTime
+  }
+
+  /**
+   * Stops a timer and returns the elapsed time in seconds.
+   */
+  def stop(timerLabel: String): Double = {
+    val currentTime = System.nanoTime()
+    if (!starts.contains(timerLabel)) {
+      throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" +
+        s" timerLabel = $timerLabel, but that timer was not started.")
+    }
+    val elapsed = currentTime - starts(timerLabel)
+    starts.remove(timerLabel)
+    if (totals.contains(timerLabel)) {
+      totals(timerLabel) += elapsed
+    } else {
+      totals(timerLabel) = elapsed
+    }
+    elapsed / 1e9
+  }
+
+  /**
+   * Print all timing results in seconds.
+   */
+  override def toString: String = {
+    totals.map { case (label, elapsed) =>
+        s"  $label: ${elapsed / 1e9}"
+      }.mkString("\n")
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
new file mode 100644
index 0000000000..ccac1031fd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.model.Bin
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * Internal representation of LabeledPoint for DecisionTree.
+ * This bins feature values based on a subsampled of data as follows:
+ *  (a) Continuous features are binned into ranges.
+ *  (b) Unordered categorical features are binned based on subsets of feature values.
+ *      "Unordered categorical features" are categorical features with low arity used in
+ *      multiclass classification.
+ *  (c) Ordered categorical features are binned based on feature values.
+ *      "Ordered categorical features" are categorical features with high arity,
+ *      or any categorical feature used in regression or binary classification.
+ *
+ * @param label  Label from LabeledPoint
+ * @param binnedFeatures  Binned feature values.
+ *                        Same length as LabeledPoint.features, but values are bin indices.
+ */
+private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
+  extends Serializable {
+}
+
+private[tree] object TreePoint {
+
+  /**
+   * Convert an input dataset into its TreePoint representation,
+   * binning feature values in preparation for DecisionTree training.
+   * @param input     Input dataset.
+   * @param strategy  DecisionTree training info, used for dataset metadata.
+   * @param bins      Bins for features, of size (numFeatures, numBins).
+   * @return  TreePoint dataset representation
+   */
+  def convertToTreeRDD(
+      input: RDD[LabeledPoint],
+      strategy: Strategy,
+      bins: Array[Array[Bin]]): RDD[TreePoint] = {
+    input.map { x =>
+      TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins,
+        strategy.categoricalFeaturesInfo)
+    }
+  }
+
+  /**
+   * Convert one LabeledPoint into its TreePoint representation.
+   * @param bins      Bins for features, of size (numFeatures, numBins).
+   * @param categoricalFeaturesInfo  Map over categorical features: feature index --> feature arity
+   */
+  private def labeledPointToTreePoint(
+      labeledPoint: LabeledPoint,
+      isMulticlassClassification: Boolean,
+      bins: Array[Array[Bin]],
+      categoricalFeaturesInfo: Map[Int, Int]): TreePoint = {
+
+    val numFeatures = labeledPoint.features.size
+    val numBins = bins(0).size
+    val arr = new Array[Int](numFeatures)
+    var featureIndex = 0
+    while (featureIndex < numFeatures) {
+      val featureInfo = categoricalFeaturesInfo.get(featureIndex)
+      val isFeatureContinuous = featureInfo.isEmpty
+      if (isFeatureContinuous) {
+        arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false,
+          bins, categoricalFeaturesInfo)
+      } else {
+        val featureCategories = featureInfo.get
+        val isSpaceSufficientForAllCategoricalSplits
+          = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+        val isUnorderedFeature =
+          isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
+        arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous,
+          isUnorderedFeature, bins, categoricalFeaturesInfo)
+      }
+      featureIndex += 1
+    }
+
+    new TreePoint(labeledPoint.label, arr)
+  }
+
+  /**
+   * Find bin for one (labeledPoint, feature).
+   *
+   * @param isUnorderedFeature  (only applies if feature is categorical)
+   * @param bins   Bins for features, of size (numFeatures, numBins).
+   * @param categoricalFeaturesInfo  Map over categorical features: feature index --> feature arity
+   */
+  private def findBin(
+      featureIndex: Int,
+      labeledPoint: LabeledPoint,
+      isFeatureContinuous: Boolean,
+      isUnorderedFeature: Boolean,
+      bins: Array[Array[Bin]],
+      categoricalFeaturesInfo: Map[Int, Int]): Int = {
+
+    /**
+     * Binary search helper method for continuous feature.
+     */
+    def binarySearchForBins(): Int = {
+      val binForFeatures = bins(featureIndex)
+      val feature = labeledPoint.features(featureIndex)
+      var left = 0
+      var right = binForFeatures.length - 1
+      while (left <= right) {
+        val mid = left + (right - left) / 2
+        val bin = binForFeatures(mid)
+        val lowThreshold = bin.lowSplit.threshold
+        val highThreshold = bin.highSplit.threshold
+        if ((lowThreshold < feature) && (highThreshold >= feature)) {
+          return mid
+        } else if (lowThreshold >= feature) {
+          right = mid - 1
+        } else {
+          left = mid + 1
+        }
+      }
+      -1
+    }
+
+    /**
+     * Sequential search helper method to find bin for categorical feature in multiclass
+     * classification. The category is returned since each category can belong to multiple
+     * splits. The actual left/right child allocation per split is performed in the
+     * sequential phase of the bin aggregate operation.
+     */
+    def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = {
+      labeledPoint.features(featureIndex).toInt
+    }
+
+    /**
+     * Sequential search helper method to find bin for categorical feature
+     * (for classification and regression).
+     */
+    def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
+      val featureCategories = categoricalFeaturesInfo(featureIndex)
+      val featureValue = labeledPoint.features(featureIndex)
+      var binIndex = 0
+      while (binIndex < featureCategories) {
+        val bin = bins(featureIndex)(binIndex)
+        val categories = bin.highSplit.categories
+        if (categories.contains(featureValue)) {
+          return binIndex
+        }
+        binIndex += 1
+      }
+      if (featureValue < 0 || featureValue >= featureCategories) {
+        throw new IllegalArgumentException(
+          s"DecisionTree given invalid data:" +
+            s" Feature $featureIndex is categorical with values in" +
+            s" {0,...,${featureCategories - 1}," +
+            s" but a data point gives it value $featureValue.\n" +
+            "  Bad data point: " + labeledPoint.toString)
+      }
+      -1
+    }
+
+    if (isFeatureContinuous) {
+      // Perform binary search for finding bin for continuous features.
+      val binIndex = binarySearchForBins()
+      if (binIndex == -1) {
+        throw new RuntimeException("No bin was found for continuous feature." +
+          " This error can occur when given invalid data values (such as NaN)." +
+          s" Feature index: $featureIndex.  Feature value: ${labeledPoint.features(featureIndex)}")
+      }
+      binIndex
+    } else {
+      // Perform sequential search to find bin for categorical features.
+      val binIndex = if (isUnorderedFeature) {
+          sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
+        } else {
+          sequentialBinSearchForOrderedCategoricalFeature()
+        }
+      if (binIndex == -1) {
+        throw new RuntimeException("No bin was found for categorical feature." +
+          " This error can occur when given invalid data values (such as NaN)." +
+          s" Feature index: $featureIndex.  Feature value: ${labeledPoint.features(featureIndex)}")
+      }
+      binIndex
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 70ca7c8a26..a5c49a38dc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -21,11 +21,12 @@ import scala.collection.JavaConverters._
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
-import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
+import org.apache.spark.mllib.tree.impl.TreePoint
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.LocalSparkContext
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -41,7 +42,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       prediction != expected.label
     }
     val accuracy = (input.length - numOffPredictions).toDouble / input.length
-    assert(accuracy >= requiredAccuracy)
+    assert(accuracy >= requiredAccuracy,
+      s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
   }
 
   def validateRegressor(
@@ -54,7 +56,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       err * err
     }.sum
     val mse = squaredError / input.length
-    assert(mse <= requiredMSE)
+    assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
   }
 
   test("split and bin calculation") {
@@ -427,7 +429,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
-    val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
       Array[List[Filter]](), splits, bins, 10)
 
     val split = bestSplits(0)._1
@@ -454,7 +457,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
-    val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
       Array[List[Filter]](), splits, bins, 10)
 
     val split = bestSplits(0)._1
@@ -499,7 +503,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
       Array[List[Filter]](), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
@@ -521,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
       Array[List[Filter]](), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
@@ -544,7 +550,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
       Array[List[Filter]](), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
@@ -567,7 +574,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
       Array[List[Filter]](), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
@@ -596,7 +604,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val parentImpurities = Array(0.5, 0.5, 0.5)
 
     // Single group second level tree construction.
-    val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters,
+    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters,
       splits, bins, 10)
     assert(bestSplits.length === 2)
     assert(bestSplits(0)._2.gain > 0)
@@ -604,7 +613,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
     // level tree construction.
-    val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1,
+    val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1,
       filters, splits, bins, 0)
     assert(bestSplitsWithGroups.length === 2)
     assert(bestSplitsWithGroups(0)._2.gain > 0)
@@ -630,7 +639,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
     assert(strategy.isMulticlassClassification)
     val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
       Array[List[Filter]](), splits, bins, 10)
 
     assert(bestSplits.length === 1)
@@ -689,7 +699,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(model.depth === 1)
 
     val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
       Array[List[Filter]](), splits, bins, 10)
 
     assert(bestSplits.length === 1)
@@ -714,7 +725,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     validateClassifier(model, arr, 0.9)
 
     val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
       Array[List[Filter]](), splits, bins, 10)
 
     assert(bestSplits.length === 1)
@@ -738,7 +750,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     validateClassifier(model, arr, 0.9)
 
     val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
       Array[List[Filter]](), splits, bins, 10)
 
     assert(bestSplits.length === 1)
@@ -757,7 +770,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
     assert(strategy.isMulticlassClassification)
     val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
       Array[List[Filter]](), splits, bins, 10)
 
     assert(bestSplits.length === 1)
-- 
GitLab