diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 2a3107a13e91686f9fbd1ccb85fe21b7ab371320..6b9a8f72c244efa07f7da9ce40ed963ecab96a5d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
+import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
@@ -62,43 +62,38 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     timer.start("init")
 
     val retaggedInput = input.retag(classOf[LabeledPoint])
+    val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy)
     logDebug("algo = " + strategy.algo)
 
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
     timer.start("findSplitsBins")
-    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
     val numBins = bins(0).length
     timer.stop("findSplitsBins")
     logDebug("numBins = " + numBins)
 
+    // Bin feature values (TreePoint representation).
     // Cache input RDD for speedup during multiple passes.
-    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
+    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
       .persist(StorageLevel.MEMORY_AND_DISK)
 
+    val numFeatures = metadata.numFeatures
     // depth of the decision tree
     val maxDepth = strategy.maxDepth
     // the max number of nodes possible given the depth of the tree
-    val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1
-    // Initialize an array to hold filters applied to points for each node.
-    val filters = new Array[List[Filter]](maxNumNodes)
-    // The filter at the top node is an empty list.
-    filters(0) = List()
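+    // Note: (2 << maxDepth) - 1 == 2^(maxDepth + 1) - 1, the number of nodes
+    // in a complete binary tree of depth maxDepth (e.g. maxDepth = 2 gives 7).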
+    val maxNumNodes = (2 << maxDepth) - 1
     // Initialize an array to hold parent impurity calculations for each node.
     val parentImpurities = new Array[Double](maxNumNodes)
     // dummy value for top node (updated during first split calculation)
     val nodes = new Array[Node](maxNumNodes)
-    // num features
-    val numFeatures = treeInput.take(1)(0).binnedFeatures.size
 
     // Calculate level for single group construction
 
     // Max memory usage for aggregates
     val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
     logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
-    val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins,
-      strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures,
-      strategy.algo)
+    val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins)
 
     logDebug("numElementsPerNode = " + numElementsPerNode)
     val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
@@ -114,9 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     /*
      * The main idea here is to perform level-wise training of the decision tree nodes thus
      * reducing the passes over the data from l to log2(l) where l is the total number of nodes.
-     * Each data sample is checked for validity w.r.t to each node at a given level -- i.e.,
-     * the sample is only used for the split calculation at the node if the sampled would have
-     * still survived the filters of the parent nodes.
+     * Each data sample is handled by a particular node at that level (or it reaches a leaf
+     * beforehand and is not used in later levels).
      */
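+    // E.g. a full tree of depth 3 has 15 nodes but needs only 4 passes
+    // over the data, one per level.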
 
     var level = 0
@@ -130,22 +124,37 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       // Find best split for all nodes at a level.
       timer.start("findBestSplits")
       val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
-        strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer)
+        metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
       timer.stop("findBestSplits")
 
+      val levelNodeIndexOffset = (1 << level) - 1
       for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
+        val nodeIndex = levelNodeIndexOffset + index
+        val isLeftChild = level != 0 && nodeIndex % 2 == 1
+        // The parent index evaluates to -1 for the root node, where it is unused.
+        val parentNodeIndex = if (isLeftChild) {
+            (nodeIndex - 1) / 2
+          } else {
+            (nodeIndex - 2) / 2
+          }
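+        // E.g. at level 2 the nodes occupy global indices 3..6; index 0 within
+        // the level is global node 3, a left child of node (3 - 1) / 2 = 1.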
+        // Extract info for this node (index) at the current level.
         timer.start("extractNodeInfo")
-        // Extract info for nodes at the current level.
         extractNodeInfo(nodeSplitStats, level, index, nodes)
         timer.stop("extractNodeInfo")
-        timer.start("extractInfoForLowerLevels")
+        if (level != 0) {
+          // Set parent.
+          if (isLeftChild) {
+            nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex))
+          } else {
+            nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex))
+          }
+        }
         // Extract info for nodes at the next lower level.
-        extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
-          filters)
+        timer.start("extractInfoForLowerLevels")
+        extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities)
         timer.stop("extractInfoForLowerLevels")
         logDebug("final best split = " + nodeSplitStats._1)
       }
-      require(math.pow(2, level) == splitsStatsForLevel.length)
+      require((1 << level) == splitsStatsForLevel.length)
      // Check whether all the nodes at the current level are leaves.
       val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
       logDebug("all leaf = " + allLeaf)
@@ -183,7 +192,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       nodes: Array[Node]): Unit = {
     val split = nodeSplitStats._1
     val stats = nodeSplitStats._2
-    val nodeIndex = math.pow(2, level).toInt - 1 + index
+    val nodeIndex = (1 << level) - 1 + index
     val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
     val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
     logDebug("Node = " + node)
@@ -198,31 +207,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       index: Int,
       maxDepth: Int,
       nodeSplitStats: (Split, InformationGainStats),
-      parentImpurities: Array[Double],
-      filters: Array[List[Filter]]): Unit = {
-    // 0 corresponds to the left child node and 1 corresponds to the right child node.
-    var i = 0
-    while (i <= 1) {
-     // Calculate the index of the node from the node level and the index at the current level.
-      val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i
-      if (level < maxDepth) {
-        val impurity = if (i == 0) {
-          nodeSplitStats._2.leftImpurity
-        } else {
-          nodeSplitStats._2.rightImpurity
-        }
-        logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
-        // noting the parent impurities
-        parentImpurities(nodeIndex) = impurity
-        // noting the parents filters for the child nodes
-        val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
-        filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
-        for (filter <- filters(nodeIndex)) {
-          logDebug("Filter = " + filter)
-        }
-      }
-      i += 1
+      parentImpurities: Array[Double]): Unit = {
+
+    if (level >= maxDepth) {
+      return
     }
+
+    val leftNodeIndex = (2 << level) - 1 + 2 * index
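+    // (2 << level) - 1 is the global index of the first node at level + 1;
+    // e.g. level = 1, index = 1 gives leftNodeIndex = 3 + 2 = 5.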
+    val leftImpurity = nodeSplitStats._2.leftImpurity
+    logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity)
+    parentImpurities(leftNodeIndex) = leftImpurity
+
+    val rightNodeIndex = leftNodeIndex + 1
+    val rightImpurity = nodeSplitStats._2.rightImpurity
+    logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity)
+    parentImpurities(rightNodeIndex) = rightImpurity
   }
 }
 
@@ -434,10 +433,8 @@ object DecisionTree extends Serializable with Logging {
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
    * @param parentImpurities Impurities for all parent nodes for the current level
-   * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
-   *                 parameters for constructing the DecisionTree
+   * @param metadata Learning and dataset metadata
    * @param level Level of the tree
-   * @param filters Filters for all nodes at a given level
    * @param splits possible splits for all features
    * @param bins possible bins for all features
    * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
@@ -446,9 +443,9 @@ object DecisionTree extends Serializable with Logging {
   protected[tree] def findBestSplits(
       input: RDD[TreePoint],
       parentImpurities: Array[Double],
-      strategy: Strategy,
+      metadata: DecisionTreeMetadata,
       level: Int,
-      filters: Array[List[Filter]],
+      nodes: Array[Node],
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       maxLevelForSingleGroup: Int,
@@ -459,34 +456,32 @@ object DecisionTree extends Serializable with Logging {
       // the nodes are divided into multiple groups at each level with the number of groups
       // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
       // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
-      val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt
+      val numGroups = 1 << (level - maxLevelForSingleGroup)
       logDebug("numGroups = " + numGroups)
       var bestSplits = new Array[(Split, InformationGainStats)](0)
       // Iterate over each group of nodes at a level.
       var groupIndex = 0
       while (groupIndex < numGroups) {
-        val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
-          filters, splits, bins, timer, numGroups, groupIndex)
+        val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level,
+          nodes, splits, bins, timer, numGroups, groupIndex)
         bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
         groupIndex += 1
       }
       bestSplits
     } else {
-      findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer)
+      findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer)
     }
   }
 
-    /**
+  /**
    * Returns an array of optimal splits for a group of nodes at a given level
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
    * @param parentImpurities Impurities for all parent nodes for the current level
-   * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
-   *                 parameters for constructing the DecisionTree
+   * @param metadata Learning and dataset metadata
    * @param level Level of the tree
-   * @param filters Filters for all nodes at a given level
    * @param splits possible splits for all features
-   * @param bins possible bins for all features
+   * @param bins possible bins for all features, indexed as (numFeatures)(numBins)
    * @param numGroups total number of node groups at the current level. Default value is set to 1.
    * @param groupIndex index of the node group being processed. Default value is set to 0.
    * @return array of splits with best splits for all nodes at a given level.
@@ -494,9 +489,9 @@ object DecisionTree extends Serializable with Logging {
   private def findBestSplitsPerGroup(
       input: RDD[TreePoint],
       parentImpurities: Array[Double],
-      strategy: Strategy,
+      metadata: DecisionTreeMetadata,
       level: Int,
-      filters: Array[List[Filter]],
+      nodes: Array[Node],
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       timer: TimeTracker,
@@ -515,7 +510,7 @@ object DecisionTree extends Serializable with Logging {
      * We use a bin-wise best split computation strategy instead of a straightforward best split
      * computation strategy. Instead of analyzing each sample for contribution to the left/right
      * child node impurity of every split, we first categorize each feature of a sample into a
-     * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin,
+     * bin. Each bin is an interval between a low and high split. Since each split, and thus bin,
      * is ordered (read ordering for categorical variables in the findSplitsBins method),
      * we exploit this structure to calculate aggregates for bins and then use these aggregates
      * to calculate information gain for each split.
@@ -531,160 +526,124 @@ object DecisionTree extends Serializable with Logging {
 
     // numNodes:  Number of nodes in this (level of tree, group),
     //            where nodes at deeper (larger) levels may be divided into groups.
-    val numNodes = math.pow(2, level).toInt / numGroups
+    val numNodes = (1 << level) / numGroups
     logDebug("numNodes = " + numNodes)
 
     // Find the number of features by looking at the first sample.
-    val numFeatures = input.first().binnedFeatures.size
+    val numFeatures = metadata.numFeatures
     logDebug("numFeatures = " + numFeatures)
 
     // numBins:  Number of bins = 1 + number of possible splits
     val numBins = bins(0).length
     logDebug("numBins = " + numBins)
 
-    val numClasses = strategy.numClassesForClassification
+    val numClasses = metadata.numClasses
     logDebug("numClasses = " + numClasses)
 
-    val isMulticlassClassification = strategy.isMulticlassClassification
-    logDebug("isMulticlassClassification = " + isMulticlassClassification)
+    val isMulticlass = metadata.isMulticlass
+    logDebug("isMulticlass = " + isMulticlass)
 
-    val isMulticlassClassificationWithCategoricalFeatures
-      = strategy.isMulticlassWithCategoricalFeatures
-    logDebug("isMultiClassWithCategoricalFeatures = " +
-      isMulticlassClassificationWithCategoricalFeatures)
+    val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures
+    logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures)
 
     // shift when more than one group is used at deep tree level
     val groupShift = numNodes * groupIndex
 
-    /** Find the filters used before reaching the current code. */
-    def findParentFilters(nodeIndex: Int): List[Filter] = {
-      if (level == 0) {
-        List[Filter]()
-      } else {
-        val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift
-        filters(nodeFilterIndex)
-      }
-    }
-
     /**
-     * Find whether the sample is valid input for the current node, i.e., whether it passes through
-     * all the filters for the current node.
+     * Get the node index corresponding to this data point.
+     * This function mimics prediction, passing an example from the root node down to a node
+     * at the current level being trained; that node's index is returned.
+     *
+     * @return  Leaf index if the data point reaches a leaf.
+     *          Otherwise, the index of the deepest node in the tree reached by this example.
      */
-    def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): Boolean = {
-      // leaf
-      if ((level > 0) && (parentFilters.length == 0)) {
-        return false
-      }
-
-      // Apply each filter and check sample validity. Return false when invalid condition found.
-      parentFilters.foreach { filter =>
-        val featureIndex = filter.split.feature
-        val comparison = filter.comparison
-        val isFeatureContinuous = filter.split.featureType == Continuous
-        if (isFeatureContinuous) {
-          val binId = treePoint.binnedFeatures(featureIndex)
-          val bin = bins(featureIndex)(binId)
-          val featureValue = bin.highSplit.threshold
-          val threshold = filter.split.threshold
-          comparison match {
-            case -1 => if (featureValue > threshold) return false
-            case 1 => if (featureValue <= threshold) return false
+    def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = {
+      if (node.isLeaf) {
+        node.id
+      } else {
+        val featureIndex = node.split.get.feature
+        val splitLeft = node.split.get.featureType match {
+          case Continuous => {
+            val binIndex = binnedFeatures(featureIndex)
+            val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
+            // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
+            // We do not need to check lowSplit since bins are separated by splits.
+            featureValueUpperBound <= node.split.get.threshold
           }
-        } else {
-          val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-          val isSpaceSufficientForAllCategoricalSplits =
-            numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1
-          val isUnorderedFeature =
-            isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
-          val featureValue = if (isUnorderedFeature) {
-            treePoint.binnedFeatures(featureIndex)
+          case Categorical => {
+            val featureValue = if (metadata.isUnordered(featureIndex)) {
+                binnedFeatures(featureIndex)
+              } else {
+                val binIndex = binnedFeatures(featureIndex)
+                bins(featureIndex)(binIndex).category
+              }
+            node.split.get.categories.contains(featureValue)
+          }
+          case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
+        }
+        if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
+          // Return index from next layer of nodes to train
+          if (splitLeft) {
+            node.id * 2 + 1 // left
           } else {
-            val binId = treePoint.binnedFeatures(featureIndex)
-            bins(featureIndex)(binId).category
+            node.id * 2 + 2 // right
           }
-          val containsFeature = filter.split.categories.contains(featureValue)
-          comparison match {
-            case -1 => if (!containsFeature) return false
-            case 1 => if (containsFeature) return false
+        } else {
+          if (splitLeft) {
+            predictNodeIndex(node.leftNode.get, binnedFeatures)
+          } else {
+            predictNodeIndex(node.rightNode.get, binnedFeatures)
           }
         }
       }
+    }
 
-      // Return true when the sample is valid for all filters.
-      true
+    def nodeIndexToLevel(idx: Int): Int = {
+      if (idx == 0) {
+        0
+      } else {
+        math.floor(math.log(idx + 1) / math.log(2)).toInt
+      }
     }
 
+    // Used for treePointToNodeIndex
+    val levelOffset = (1 << level) - 1
+
     /**
-     * Finds bins for all nodes (and all features) at a given level.
-     * For l nodes, k features the storage is as follows:
-     * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk,
-     * where b_ij is an integer between 0 and numBins - 1 for regressions and binary
-     * classification and the categorical feature value in  multiclass classification.
-     * Invalid sample is denoted by noting bin for feature 1 as -1.
-     *
-     * For unordered features, the "bin index" returned is actually the feature value (category).
-     *
-     * @return  Array of size 1 + numFeatures * numNodes, where
-     *          arr(0) = label for labeledPoint, and
-     *          arr(1 + numFeatures * nodeIndex + featureIndex) =
-     *            bin index for this labeledPoint
-     *            (or InvalidBinIndex if labeledPoint is not handled by this node)
+     * Find the node index for the given example.
+     * Nodes are indexed from 0 at the start of this (level, group).
+     * If the example does not reach this level, returns a value < 0.
      */
-    def findBinsForLevel(treePoint: TreePoint): Array[Double] = {
-      // Calculate bin index and label per feature per node.
-      val arr = new Array[Double](1 + (numFeatures * numNodes))
-      // First element of the array is the label of the instance.
-      arr(0) = treePoint.label
-      // Iterate over nodes.
-      var nodeIndex = 0
-      while (nodeIndex < numNodes) {
-        val parentFilters = findParentFilters(nodeIndex)
-        // Find out whether the sample qualifies for the particular node.
-        val sampleValid = isSampleValid(parentFilters, treePoint)
-        val shift = 1 + numFeatures * nodeIndex
-        if (!sampleValid) {
-          // Mark one bin as -1 is sufficient.
-          arr(shift) = InvalidBinIndex
-        } else {
-          var featureIndex = 0
-          while (featureIndex < numFeatures) {
-            arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex)
-            featureIndex += 1
-          }
-        }
-        nodeIndex += 1
+    def treePointToNodeIndex(treePoint: TreePoint): Int = {
+      if (level == 0) {
+        0
+      } else {
+        val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.binnedFeatures)
+        // Get index for this (level, group).
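+        // E.g. at (level = 2, groupIndex = 0), global node 5 maps to
+        // index 5 - 3 - 0 = 2 within this (level, group).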
+        globalNodeIndex - levelOffset - groupShift
       }
-      arr
     }
 
-    // Find feature bins for all nodes at a level.
-    timer.start("aggregation")
-    val binMappedRDD = input.map(x => findBinsForLevel(x))
-
     /**
      * Increment aggregate in location for (node, feature, bin, label).
      *
-     * @param arr  Bin mapping from findBinsForLevel.  arr(0) stores the class label.
-     *             Array of size 1 + (numFeatures * numNodes).
+     * @param treePoint  Data point being aggregated.
      * @param agg  Array storing aggregate calculation, of size:
      *             numClasses * numBins * numFeatures * numNodes.
      *             Indexed by (node, feature, bin, label) where label is the least significant bit.
+     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
      */
     def updateBinForOrderedFeature(
-        arr: Array[Double],
+        treePoint: TreePoint,
         agg: Array[Double],
         nodeIndex: Int,
-        label: Double,
         featureIndex: Int): Unit = {
-      // Find the bin index for this feature.
-      val arrShift = 1 + numFeatures * nodeIndex
-      val arrIndex = arrShift + featureIndex
       // Update the left or right count for one bin.
       val aggIndex =
         numClasses * numBins * numFeatures * nodeIndex +
         numClasses * numBins * featureIndex +
-        numClasses * arr(arrIndex).toInt +
-        label.toInt
+        numClasses * treePoint.binnedFeatures(featureIndex) +
+        treePoint.label.toInt
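+      // E.g. with numClasses = 2, numBins = 4, numFeatures = 3, the entry for
+      // (node 1, feature 2, bin 3, label 1) is 24 + 16 + 6 + 1 = 47.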
       agg(aggIndex) += 1
     }
 
@@ -693,8 +652,8 @@ object DecisionTree extends Serializable with Logging {
      * where [bins] ranges over all bins.
      * Updates left or right side of aggregate depending on split.
      *
-     * @param arr  arr(0) = label.
-     *             arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category)
+     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
+     * @param treePoint  Data point being aggregated.
      * @param agg  Indexed by (left/right, node, feature, bin, label)
      *             where label is the least significant bit.
      *             The left/right specifier is a 0/1 index indicating left/right child info.
@@ -703,21 +662,18 @@ object DecisionTree extends Serializable with Logging {
     def updateBinForUnorderedFeature(
         nodeIndex: Int,
         featureIndex: Int,
-        arr: Array[Double],
-        label: Double,
+        treePoint: TreePoint,
         agg: Array[Double],
         rightChildShift: Int): Unit = {
-      // Find the bin index for this feature.
-      val arrIndex = 1 + numFeatures * nodeIndex + featureIndex
-      val featureValue = arr(arrIndex).toInt
+      val featureValue = treePoint.binnedFeatures(featureIndex)
       // Update the left or right count for one bin.
       val aggShift =
         numClasses * numBins * numFeatures * nodeIndex +
         numClasses * numBins * featureIndex +
-        label.toInt
+        treePoint.label.toInt
       // Find all matching bins and increment their values
-      val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-      val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+      val featureCategories = metadata.featureArity(featureIndex)
+      val numCategoricalBins = (1 << (featureCategories - 1)) - 1
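+      // An unordered feature with M categories admits 2^(M - 1) - 1 candidate
+      // splits (e.g. M = 3 gives 3), one per nontrivial category subset.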
       var binIndex = 0
       while (binIndex < numCategoricalBins) {
         val aggIndex = aggShift + binIndex * numClasses
@@ -733,30 +689,21 @@ object DecisionTree extends Serializable with Logging {
     /**
      * Helper for binSeqOp.
      *
-     * @param arr  Bin mapping from findBinsForLevel. arr(0) stores the class label.
-     *             Array of size 1 + (numFeatures * numNodes).
      * @param agg  Array storing aggregate calculation, of size:
      *             numClasses * numBins * numFeatures * numNodes.
      *             Indexed by (node, feature, bin, label) where label is the least significant bit.
+     * @param treePoint  Data point being aggregated.
+     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
      */
-    def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = {
-      // Iterate over all nodes.
-      var nodeIndex = 0
-      while (nodeIndex < numNodes) {
-        // Check whether the instance was valid for this nodeIndex.
-        val validSignalIndex = 1 + numFeatures * nodeIndex
-        val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
-        if (isSampleValidForNode) {
-          // actual class label
-          val label = arr(0)
-          // Iterate over all features.
-          var featureIndex = 0
-          while (featureIndex < numFeatures) {
-            updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
-            featureIndex += 1
-          }
-        }
-        nodeIndex += 1
+    def binaryOrNotCategoricalBinSeqOp(
+        agg: Array[Double],
+        treePoint: TreePoint,
+        nodeIndex: Int): Unit = {
+      // Iterate over all features.
+      var featureIndex = 0
+      while (featureIndex < numFeatures) {
+        updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex)
+        featureIndex += 1
       }
     }
 
@@ -765,49 +712,28 @@ object DecisionTree extends Serializable with Logging {
     /**
      * Helper for binSeqOp.
      *
-     * @param arr  Bin mapping from findBinsForLevel. arr(0) stores the class label.
-     *             Array of size 1 + (numFeatures * numNodes).
-     *             For ordered features,
-     *               arr(1 + featureIndex + nodeIndex * numFeatures) = bin index.
-     *             For unordered features,
-     *               arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category).
      * @param agg  Array storing aggregate calculation.
      *             For ordered features, this is of size:
      *               numClasses * numBins * numFeatures * numNodes.
      *             For unordered features, this is of size:
      *               2 * numClasses * numBins * numFeatures * numNodes.
+     * @param treePoint   Data point being aggregated.
+     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
      */
-    def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = {
-      // Iterate over all nodes.
-      var nodeIndex = 0
-      while (nodeIndex < numNodes) {
-        // Check whether the instance was valid for this nodeIndex.
-        val validSignalIndex = 1 + numFeatures * nodeIndex
-        val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
-        if (isSampleValidForNode) {
-          // actual class label
-          val label = arr(0)
-          // Iterate over all features.
-          var featureIndex = 0
-          while (featureIndex < numFeatures) {
-            val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-            if (isFeatureContinuous) {
-              updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
-            } else {
-              val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-              val isSpaceSufficientForAllCategoricalSplits
-                = numBins > math.pow(2, featureCategories.toInt - 1) - 1
-              if (isSpaceSufficientForAllCategoricalSplits) {
-                updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg,
-                  rightChildShift)
-              } else {
-                updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
-              }
-            }
-            featureIndex += 1
-          }
+    def multiclassWithCategoricalBinSeqOp(
+        agg: Array[Double],
+        treePoint: TreePoint,
+        nodeIndex: Int): Unit = {
+      // Iterate over all features.
+      var featureIndex = 0
+      while (featureIndex < numFeatures) {
+        if (metadata.isUnordered(featureIndex)) {
+          updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift)
+        } else {
+          updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex)
         }
-        nodeIndex += 1
+        featureIndex += 1
       }
     }
 
@@ -818,36 +744,25 @@ object DecisionTree extends Serializable with Logging {
      *
      * @param agg Array storing aggregate calculation, updated by this function.
      *            Size: 3 * numBins * numFeatures * numNodes
-     * @param arr Bin mapping from findBinsForLevel.
-     *             Array of size 1 + (numFeatures * numNodes).
+     * @param treePoint   Data point being aggregated.
+     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
      * @return agg
      */
-    def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = {
-      // Iterate over all nodes.
-      var nodeIndex = 0
-      while (nodeIndex < numNodes) {
-        // Check whether the instance was valid for this nodeIndex.
-        val validSignalIndex = 1 + numFeatures * nodeIndex
-        val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
-        if (isSampleValidForNode) {
-          // actual class label
-          val label = arr(0)
-          // Iterate over all features.
-          var featureIndex = 0
-          while (featureIndex < numFeatures) {
-            // Find the bin index for this feature.
-            val arrShift = 1 + numFeatures * nodeIndex
-            val arrIndex = arrShift + featureIndex
-            // Update count, sum, and sum^2 for one bin.
-            val aggShift = 3 * numBins * numFeatures * nodeIndex
-            val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3
-            agg(aggIndex) = agg(aggIndex) + 1
-            agg(aggIndex + 1) = agg(aggIndex + 1) + label
-            agg(aggIndex + 2) = agg(aggIndex + 2) + label * label
-            featureIndex += 1
-          }
-        }
-        nodeIndex += 1
+    def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = {
+      val label = treePoint.label
+      // Iterate over all features.
+      var featureIndex = 0
+      while (featureIndex < numFeatures) {
+        // Update count, sum, and sum^2 for one bin.
+        val binIndex = treePoint.binnedFeatures(featureIndex)
+        val aggIndex =
+          3 * numBins * numFeatures * nodeIndex +
+          3 * numBins * featureIndex +
+          3 * binIndex
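+        // Each bin stores a (count, sum, sum of squares) triple, so the
+        // flattened index advances in steps of 3.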
+        agg(aggIndex) += 1
+        agg(aggIndex + 1) += label
+        agg(aggIndex + 2) += label * label
+        featureIndex += 1
       }
     }
 
@@ -866,26 +781,30 @@ object DecisionTree extends Serializable with Logging {
      *              2 * numClasses * numBins * numFeatures * numNodes for unordered features.
      *            Size for regression:
      *              3 * numBins * numFeatures * numNodes.
-     * @param arr  Bin mapping from findBinsForLevel.
-     *             Array of size 1 + (numFeatures * numNodes).
+     * @param treePoint   Data point being aggregated.
      * @return  agg
      */
-    def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
-      strategy.algo match {
-        case Classification =>
-          if(isMulticlassClassificationWithCategoricalFeatures) {
-            multiclassWithCategoricalBinSeqOp(arr, agg)
+    def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = {
+      val nodeIndex = treePointToNodeIndex(treePoint)
+      // If the example does not reach this level, then nodeIndex < 0.
+      // If the example reaches this level but is handled in a different group,
+      //  then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group).
+      if (nodeIndex >= 0 && nodeIndex < numNodes) {
+        if (metadata.isClassification) {
+          if (isMulticlassWithCategoricalFeatures) {
+            multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex)
           } else {
-            binaryOrNotCategoricalBinSeqOp(arr, agg)
+            binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex)
           }
-        case Regression => regressionBinSeqOp(arr, agg)
+        } else {
+          regressionBinSeqOp(agg, treePoint, nodeIndex)
+        }
       }
       agg
     }
 
     // Calculate bin aggregate length for classification or regression.
-    val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses,
-        isMulticlassClassificationWithCategoricalFeatures, strategy.algo)
+    val binAggregateLength = numNodes * getElementsPerNode(metadata, numBins)
     logDebug("binAggregateLength = " + binAggregateLength)
 
     /**
@@ -905,144 +824,134 @@ object DecisionTree extends Serializable with Logging {
     }
 
     // Calculate bin aggregates.
+    timer.start("aggregation")
     val binAggregates = {
-      binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
+      input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
     }
     timer.stop("aggregation")
     logDebug("binAggregates.length = " + binAggregates.length)
 
     /**
-     * Calculates the information gain for all splits based upon left/right split aggregates.
-     * @param leftNodeAgg left node aggregates
-     * @param featureIndex feature index
-     * @param splitIndex split index
-     * @param rightNodeAgg right node aggregate
+     * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
+     * @param leftNodeAgg left node aggregates for this (feature, split)
+     * @param rightNodeAgg right node aggregates for this (feature, split)
      * @param topImpurity impurity of the parent node
      * @return information gain and statistics for all splits
      */
     def calculateGainForSplit(
-        leftNodeAgg: Array[Array[Array[Double]]],
-        featureIndex: Int,
-        splitIndex: Int,
-        rightNodeAgg: Array[Array[Array[Double]]],
+        leftNodeAgg: Array[Double],
+        rightNodeAgg: Array[Double],
         topImpurity: Double): InformationGainStats = {
-      strategy.algo match {
-        case Classification =>
-          val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex)
-          val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex)
-          val leftTotalCount = leftCounts.sum
-          val rightTotalCount = rightCounts.sum
-
-          val impurity = {
-            if (level > 0) {
-              topImpurity
-            } else {
-              // Calculate impurity for root node.
-              val rootNodeCounts = new Array[Double](numClasses)
-              var classIndex = 0
-              while (classIndex < numClasses) {
-                rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex)
-                classIndex += 1
-              }
-              strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
-            }
-          }
+      if (metadata.isClassification) {
+        val leftTotalCount = leftNodeAgg.sum
+        val rightTotalCount = rightNodeAgg.sum
 
-          val totalCount = leftTotalCount + rightTotalCount
-          if (totalCount == 0) {
-            // Return arbitrary prediction.
-            return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+        val impurity = {
+          if (level > 0) {
+            topImpurity
+          } else {
+            // Calculate impurity for root node.
+            val rootNodeCounts = new Array[Double](numClasses)
+            var classIndex = 0
+            while (classIndex < numClasses) {
+              rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex)
+              classIndex += 1
+            }
+            metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
           }
+        }
 
-          // Sum of count for each label
-          val leftRightCounts: Array[Double] =
-            leftCounts.zip(rightCounts).map { case (leftCount, rightCount) =>
-              leftCount + rightCount
-            }
+        val totalCount = leftTotalCount + rightTotalCount
+        if (totalCount == 0) {
+          // Return arbitrary prediction.
+          return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+        }
 
-          def indexOfLargestArrayElement(array: Array[Double]): Int = {
-            val result = array.foldLeft(-1, Double.MinValue, 0) {
-              case ((maxIndex, maxValue, currentIndex), currentValue) =>
-                if (currentValue > maxValue) {
-                  (currentIndex, currentValue, currentIndex + 1)
-                } else {
-                  (maxIndex, maxValue, currentIndex + 1)
-                }
-            }
-            if (result._1 < 0) {
-              throw new RuntimeException("DecisionTree internal error:" +
-                " calculateGainForSplit failed in indexOfLargestArrayElement")
-            }
-            result._1
+        // Sum of count for each label
+        val leftRightNodeAgg: Array[Double] =
+          leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) =>
+            leftCount + rightCount
           }
 
-          val predict = indexOfLargestArrayElement(leftRightCounts)
-          val prob = leftRightCounts(predict) / totalCount
-
-          val leftImpurity = if (leftTotalCount == 0) {
-            topImpurity
-          } else {
-            strategy.impurity.calculate(leftCounts, leftTotalCount)
+        def indexOfLargestArrayElement(array: Array[Double]): Int = {
+          val result = array.foldLeft(-1, Double.MinValue, 0) {
+            case ((maxIndex, maxValue, currentIndex), currentValue) =>
+              if (currentValue > maxValue) {
+                (currentIndex, currentValue, currentIndex + 1)
+              } else {
+                (maxIndex, maxValue, currentIndex + 1)
+              }
           }
-          val rightImpurity = if (rightTotalCount == 0) {
-            topImpurity
-          } else {
-            strategy.impurity.calculate(rightCounts, rightTotalCount)
+          if (result._1 < 0) {
+            throw new RuntimeException("DecisionTree internal error:" +
+              " calculateGainForSplit failed in indexOfLargestArrayElement")
           }
+          result._1
+        }
 
-          val leftWeight = leftTotalCount / totalCount
-          val rightWeight = rightTotalCount / totalCount
+        val predict = indexOfLargestArrayElement(leftRightNodeAgg)
+        val prob = leftRightNodeAgg(predict) / totalCount
 
-          val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+        val leftImpurity = if (leftTotalCount == 0) {
+          topImpurity
+        } else {
+          metadata.impurity.calculate(leftNodeAgg, leftTotalCount)
+        }
+        val rightImpurity = if (rightTotalCount == 0) {
+          topImpurity
+        } else {
+          metadata.impurity.calculate(rightNodeAgg, rightTotalCount)
+        }
 
-          new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+        val leftWeight = leftTotalCount / totalCount
+        val rightWeight = rightTotalCount / totalCount
 
-        case Regression =>
-          val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
-          val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
-          val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2)
+        val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
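+        // E.g. parent impurity 0.5 split 30/70 into children with impurities
+        // 0.2 and 0.4 yields gain 0.5 - 0.3 * 0.2 - 0.7 * 0.4 = 0.16.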
 
-          val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0)
-          val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1)
-          val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2)
+        new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
 
-          val impurity = {
-            if (level > 0) {
-              topImpurity
-            } else {
-              // Calculate impurity for root node.
-              val count = leftCount + rightCount
-              val sum = leftSum + rightSum
-              val sumSquares = leftSumSquares + rightSumSquares
-              strategy.impurity.calculate(count, sum, sumSquares)
-            }
-          }
+      } else {
+        // Regression
 
-          if (leftCount == 0) {
-            return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
-              rightSum / rightCount)
-          }
-          if (rightCount == 0) {
-            return new InformationGainStats(0, topImpurity ,topImpurity,
-              Double.MinValue, leftSum / leftCount)
+        val leftCount = leftNodeAgg(0)
+        val leftSum = leftNodeAgg(1)
+        val leftSumSquares = leftNodeAgg(2)
+
+        val rightCount = rightNodeAgg(0)
+        val rightSum = rightNodeAgg(1)
+        val rightSumSquares = rightNodeAgg(2)
+
+        val impurity = {
+          if (level > 0) {
+            topImpurity
+          } else {
+            // Calculate impurity for root node.
+            val count = leftCount + rightCount
+            val sum = leftSum + rightSum
+            val sumSquares = leftSumSquares + rightSumSquares
+            metadata.impurity.calculate(count, sum, sumSquares)
           }
+        }
+
+        if (leftCount == 0) {
+          return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
+            rightSum / rightCount)
+        }
+        if (rightCount == 0) {
+          return new InformationGainStats(0, topImpurity, topImpurity,
+            Double.MinValue, leftSum / leftCount)
+        }
 
-          val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares)
-          val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares)
+        val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares)
+        val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares)
 
-          val leftWeight = leftCount.toDouble / (leftCount + rightCount)
-          val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+        val leftWeight = leftCount.toDouble / (leftCount + rightCount)
+        val rightWeight = rightCount.toDouble / (leftCount + rightCount)
 
-          val gain = {
-            if (level > 0) {
-              impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
-            } else {
-              impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
-            }
-          }
+        val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
 
-          val predict = (leftSum + rightSum) / (leftCount + rightCount)
-          new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
+        val predict = (leftSum + rightSum) / (leftCount + rightCount)
+        new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
       }
     }
 
@@ -1065,6 +974,19 @@ object DecisionTree extends Serializable with Logging {
         binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
 
 
+      /**
+       * The input binData is indexed as (feature, bin, class).
+       * This computes cumulative sums over splits.
+       * Each (feature, class) pair is handled separately.
+       * Note: numSplits = numBins - 1.
+       * @param leftNodeAgg  Each (feature, class) slice is an array over splits.
+       *                     Element i (i = 0, ..., numSplits - 1) is set to be
+       *                     the cumulative sum (from left) over binData for bins 0, ..., i.
+       * @param rightNodeAgg Each (feature, class) slice is an array over splits.
+       *                     Element i (i = 0, ..., numSplits - 1) is set to be
+       *                     the cumulative sum (from right) over binData for bins
+       *                     i + 1, ..., numBins - 1.
+       */
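+      // E.g. bin counts [3, 1, 4] for one (feature, class) give
+      // leftNodeAgg = [3, 4] and rightNodeAgg = [5, 4] over the two splits.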
       def findAggForOrderedFeatureClassification(
           leftNodeAgg: Array[Array[Array[Double]]],
           rightNodeAgg: Array[Array[Array[Double]]],
@@ -1169,45 +1091,32 @@ object DecisionTree extends Serializable with Logging {
         }
       }
 
-      strategy.algo match {
-        case Classification =>
-          // Initialize left and right split aggregates.
-          val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
-          val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
-          var featureIndex = 0
-          while (featureIndex < numFeatures) {
-            if (isMulticlassClassificationWithCategoricalFeatures) {
-              val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-              if (isFeatureContinuous) {
-                findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
-              } else {
-                val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-                val isSpaceSufficientForAllCategoricalSplits
-                  = numBins > math.pow(2, featureCategories.toInt - 1) - 1
-                if (isSpaceSufficientForAllCategoricalSplits) {
-                  findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
-                } else {
-                  findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
-                }
-              }
-            } else {
-              findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
-            }
-            featureIndex += 1
-          }
-
-          (leftNodeAgg, rightNodeAgg)
-        case Regression =>
-          // Initialize left and right split aggregates.
-          val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
-          val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
-          // Iterate over all features.
-          var featureIndex = 0
-          while (featureIndex < numFeatures) {
-            findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex)
-            featureIndex += 1
+      if (metadata.isClassification) {
+        // Initialize left and right split aggregates.
+        val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
+        val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
+        var featureIndex = 0
+        while (featureIndex < numFeatures) {
+          if (metadata.isUnordered(featureIndex)) {
+            findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
+          } else {
+            findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
           }
-          (leftNodeAgg, rightNodeAgg)
+          featureIndex += 1
+        }
+        (leftNodeAgg, rightNodeAgg)
+      } else {
+        // Regression
+        // Initialize left and right split aggregates.
+        val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
+        val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
+        // Iterate over all features.
+        var featureIndex = 0
+        while (featureIndex < numFeatures) {
+          findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex)
+          featureIndex += 1
+        }
+        (leftNodeAgg, rightNodeAgg)
       }
     }
 
@@ -1225,8 +1134,9 @@ object DecisionTree extends Serializable with Logging {
         val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
         var splitIndex = 0
         while (splitIndex < numSplitsForFeature) {
-          gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
-            splitIndex, rightNodeAgg, nodeImpurity)
+          gains(featureIndex)(splitIndex) =
+            calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex),
+              rightNodeAgg(featureIndex)(splitIndex), nodeImpurity)
           splitIndex += 1
         }
         featureIndex += 1
@@ -1238,18 +1148,14 @@ object DecisionTree extends Serializable with Logging {
      * Get the number of splits for a feature.
      */
     def getNumSplitsForFeature(featureIndex: Int): Int = {
-      val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-      if (isFeatureContinuous) {
+      if (metadata.isContinuous(featureIndex)) {
         numBins - 1
       } else {
         // Categorical feature
-        val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-        val isSpaceSufficientForAllCategoricalSplits =
-          numBins > math.pow(2, featureCategories.toInt - 1) - 1
-        if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
-          math.pow(2.0, featureCategories - 1).toInt - 1
+        val featureCategories = metadata.featureArity(featureIndex)
+        if (metadata.isUnordered(featureIndex)) {
+          (1 << (featureCategories - 1)) - 1
         } else {
-          // Ordered features
           featureCategories
         }
       }
@@ -1308,29 +1214,29 @@ object DecisionTree extends Serializable with Logging {
      * Get bin data for one node.
      */
     def getBinDataForNode(node: Int): Array[Double] = {
-      strategy.algo match {
-        case Classification =>
-          if (isMulticlassClassificationWithCategoricalFeatures) {
-            val shift = numClasses * node * numBins * numFeatures
-            val rightChildShift = numClasses * numBins * numFeatures * numNodes
-            val binsForNode = {
-              val leftChildData
-                = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
-              val rightChildData
-              = binAggregates.slice(rightChildShift + shift,
-                rightChildShift + shift + numClasses * numBins * numFeatures)
-              leftChildData ++ rightChildData
-            }
-            binsForNode
-          } else {
-            val shift = numClasses * node * numBins * numFeatures
-            val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
-            binsForNode
+      if (metadata.isClassification) {
+        if (isMulticlassWithCategoricalFeatures) {
+          val shift = numClasses * node * numBins * numFeatures
+          val rightChildShift = numClasses * numBins * numFeatures * numNodes
+          val binsForNode = {
+            val leftChildData =
+              binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
+            val rightChildData =
+              binAggregates.slice(rightChildShift + shift,
+                rightChildShift + shift + numClasses * numBins * numFeatures)
+            leftChildData ++ rightChildData
           }
-        case Regression =>
-          val shift = 3 * node * numBins * numFeatures
-          val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
           binsForNode
+        } else {
+          val shift = numClasses * node * numBins * numFeatures
+          val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
+          binsForNode
+        }
+      } else {
+        // Regression
+        val shift = 3 * node * numBins * numFeatures
+        val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
+        binsForNode
       }
     }
 
@@ -1340,7 +1246,7 @@ object DecisionTree extends Serializable with Logging {
     // Iterating over all nodes at this level
     var node = 0
     while (node < numNodes) {
-      val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift
+      val nodeImpurityIndex = (1 << level) - 1 + node + groupShift
       val binsForNode: Array[Double] = getBinDataForNode(node)
       logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
       val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
@@ -1358,20 +1264,15 @@ object DecisionTree extends Serializable with Logging {
    *
    * @param numBins  Number of bins = 1 + number of possible splits.
    */
-  private def getElementsPerNode(
-      numFeatures: Int,
-      numBins: Int,
-      numClasses: Int,
-      isMulticlassClassificationWithCategoricalFeatures: Boolean,
-      algo: Algo): Int = {
-    algo match {
-      case Classification =>
-        if (isMulticlassClassificationWithCategoricalFeatures) {
-          2 * numClasses * numBins * numFeatures
-        } else {
-          numClasses * numBins * numFeatures
-        }
-      case Regression => 3 * numBins * numFeatures
+  private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = {
+    if (metadata.isClassification) {
+      if (metadata.isMulticlassWithCategoricalFeatures) {
+        2 * metadata.numClasses * numBins * metadata.numFeatures
+      } else {
+        metadata.numClasses * numBins * metadata.numFeatures
+      }
+    } else {
+      3 * numBins * metadata.numFeatures
     }
   }
 
@@ -1390,16 +1291,15 @@ object DecisionTree extends Serializable with Logging {
    *       For multiclass classification with a low-arity feature
    *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
    *       the feature is split based on subsets of categories.
-   *       There are math.pow(2, maxFeatureValue - 1) - 1 splits.
+   *       There are 2^(maxFeatureValue - 1) - 1 splits.
    *   (b) "ordered features"
    *       For regression and binary classification,
    *       and for multiclass classification with a high-arity feature,
    *       there is one bin per category.
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
-   * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
-   *                 parameters for construction the DecisionTree
-   * @return A tuple of (splits,bins).
+   * @param metadata Learning and dataset metadata
+   * @return A tuple of (splits, bins).
    *         Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
    *          of size (numFeatures, numBins - 1).
    *         Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
@@ -1407,19 +1307,18 @@ object DecisionTree extends Serializable with Logging {
    */
   protected[tree] def findSplitsBins(
       input: RDD[LabeledPoint],
-      strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
+      metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
 
     val count = input.count()
 
     // Find the number of features by looking at the first sample
     val numFeatures = input.take(1)(0).features.size
 
-    val maxBins = strategy.maxBins
+    val maxBins = metadata.maxBins
     val numBins = if (maxBins <= count) maxBins else count.toInt
     logDebug("numBins = " + numBins)
-    val isMulticlassClassification = strategy.isMulticlassClassification
-    logDebug("isMulticlassClassification = " + isMulticlassClassification)
-
+    val isMulticlass = metadata.isMulticlass
+    logDebug("isMulticlass = " + isMulticlass)
 
     /*
      * Ensure numBins is always greater than the categories. For multiclass classification,
@@ -1431,13 +1330,12 @@ object DecisionTree extends Serializable with Logging {
      * by the number of training examples.
      * TODO: Allow this case, where we simply will know nothing about some categories.
      */
-    if (strategy.categoricalFeaturesInfo.size > 0) {
-      val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
+    if (metadata.featureArity.size > 0) {
+      val maxCategoriesForFeatures = metadata.featureArity.maxBy(_._2)._2
       require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
         "in categorical features")
     }
 
-
     // Calculate the number of samples for approximate quantile calculation.
     val requiredSamples = numBins*numBins
     val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
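The fraction targets roughly numBins * numBins sampled examples for quantile estimation. A worked example with illustrative numbers:

    // With numBins = 32 and 1,000,000 examples, ~1024 samples are needed,
    // so about 0.1% of the data is sampled.
    val numBins = 32
    val count = 1000000L
    val requiredSamples = numBins * numBins
    val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
    assert(fraction == 0.001024)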
@@ -1451,7 +1349,7 @@ object DecisionTree extends Serializable with Logging {
     val stride: Double = numSamples.toDouble / numBins
     logDebug("stride = " + stride)
 
-    strategy.quantileCalculationStrategy match {
+    metadata.quantileStrategy match {
       case Sort =>
         val splits = Array.ofDim[Split](numFeatures, numBins - 1)
         val bins = Array.ofDim[Bin](numFeatures, numBins)
@@ -1462,7 +1360,7 @@ object DecisionTree extends Serializable with Logging {
         var featureIndex = 0
         while (featureIndex < numFeatures) {
           // Check whether the feature is continuous.
-          val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+          val isFeatureContinuous = metadata.isContinuous(featureIndex)
           if (isFeatureContinuous) {
             val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
             val stride: Double = numSamples.toDouble / numBins
@@ -1475,18 +1373,14 @@ object DecisionTree extends Serializable with Logging {
               splits(featureIndex)(index) = split
             }
           } else { // Categorical feature
-            val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-            val isSpaceSufficientForAllCategoricalSplits
-              = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+            val featureCategories = metadata.featureArity(featureIndex)
 
             // Use different bin/split calculation strategy for categorical features in multiclass
             // classification that satisfy the space constraint.
-            val isUnorderedFeature =
-              isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
-            if (isUnorderedFeature) {
+            if (metadata.isUnordered(featureIndex)) {
               // 2^(maxFeatureValue - 1) - 1 combinations
               var index = 0
-              while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
+              while (index < (1 << featureCategories - 1) - 1) {
                 val categories: List[Double]
                   = extractMultiClassCategories(index + 1, featureCategories)
                 splits(featureIndex)(index)
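This loop creates one split per non-empty proper subset of the categories; extractMultiClassCategories (not shown in this diff) decodes the split index into such a subset. A standalone sketch, under the assumption that bit j of index + 1 selects category j (categoriesFromBitmask is a hypothetical helper, not the real method):

    def categoriesFromBitmask(bitmask: Int, arity: Int): List[Double] =
      (0 until arity).filter(c => ((bitmask >> c) & 1) == 1).map(_.toDouble).toList

    // For featureCategories = 3, (1 << (3 - 1)) - 1 = 3 splits are enumerated:
    // {0}, {1}, {0, 1}.
    for (index <- 0 until (1 << (3 - 1)) - 1)
      println(s"split $index -> ${categoriesFromBitmask(index + 1, 3)}")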
@@ -1516,7 +1410,7 @@ object DecisionTree extends Serializable with Logging {
                * centroidForCategories is a mapping: category (for the given feature) --> centroid
                */
               val centroidForCategories = {
-                if (isMulticlassClassification) {
+                if (isMulticlass) {
                   // For categorical variables in multiclass classification,
                   // each bin is a category. The bins are sorted and they
                   // are ordered by calculating the impurity of their corresponding labels.
@@ -1524,7 +1418,7 @@ object DecisionTree extends Serializable with Logging {
                    .groupBy(_._1)
                    .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
                    .map(x => (x._1, x._2.values.toArray))
-                   .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum)))
+                   .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum)))
                 } else { // regression or binary classification
                   // For categorical variables in regression and binary classification,
                   // each bin is a category. The bins are sorted and they
@@ -1576,7 +1470,7 @@ object DecisionTree extends Serializable with Logging {
         // Find all bins.
         featureIndex = 0
         while (featureIndex < numFeatures) {
-          val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+          val isFeatureContinuous = metadata.isContinuous(featureIndex)
           if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
             bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
               splits(featureIndex)(0), Continuous, Double.MinValue)
@@ -1590,7 +1484,7 @@ object DecisionTree extends Serializable with Logging {
           }
           featureIndex += 1
         }
-        (splits,bins)
+        (splits, bins)
       case MinMax =>
         throw new UnsupportedOperationException("minmax not supported yet.")
       case ApproxHist =>
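For continuous features, the Sort strategy above places thresholds at regular ranks of the sorted sample (equi-depth bins). A simplified sketch of that idea; the real code walks sampledInput one feature at a time:

    // Thresholds at ranks stride, 2 * stride, ... of the sorted sample.
    val featureSamples = Array(5.0, 1.0, 3.0, 9.0, 7.0, 2.0, 8.0, 4.0).sorted
    val numBins = 4
    val stride = featureSamples.length.toDouble / numBins
    val thresholds = (1 until numBins).map(i => featureSamples((i * stride).toInt))
    assert(thresholds == Seq(3.0, 5.0, 8.0))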
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
new file mode 100644
index 0000000000000000000000000000000000000000..d9eda354dc986c66f95343a2c94bd2842618f1e1
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import scala.collection.mutable
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * Learning and dataset metadata for DecisionTree.
+ *
+ * @param numClasses    For classification: labels can take values {0, ..., numClasses - 1}.
+ *                      For regression: fixed at 0 (no meaning).
+ * @param featureArity  Map: categorical feature index --> arity.
+ *                      I.e., the feature takes values in {0, ..., arity - 1}.
+ */
+private[tree] class DecisionTreeMetadata(
+    val numFeatures: Int,
+    val numExamples: Long,
+    val numClasses: Int,
+    val maxBins: Int,
+    val featureArity: Map[Int, Int],
+    val unorderedFeatures: Set[Int],
+    val impurity: Impurity,
+    val quantileStrategy: QuantileStrategy) extends Serializable {
+
+  def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
+
+  def isClassification: Boolean = numClasses >= 2
+
+  def isMulticlass: Boolean = numClasses > 2
+
+  def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0)
+
+  def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex)
+
+  def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
+
+}
+
+private[tree] object DecisionTreeMetadata {
+
+  def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {
+
+    val numFeatures = input.take(1)(0).features.size
+    val numExamples = input.count()
+    val numClasses = strategy.algo match {
+      case Classification => strategy.numClassesForClassification
+      case Regression => 0
+    }
+
+    val maxBins = math.min(strategy.maxBins, numExamples).toInt
+    val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
+
+    val unorderedFeatures = new mutable.HashSet[Int]()
+    if (numClasses > 2) {
+      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
+        if (k - 1 < log2MaxBinsp1) {
+          // Note: The above check is equivalent to checking:
+          //       numUnorderedBins = (1 << (k - 1)) - 1 < maxBins
+          unorderedFeatures.add(f)
+        } else {
+          // TODO: Allow this case, where we simply will know nothing about some categories?
+          require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
+            s"in categorical features (>= $k)")
+        }
+      }
+    } else {
+      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
+        require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
+          s"in categorical features (>= $k)")
+      }
+    }
+
+    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
+      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
+      strategy.impurity, strategy.quantileCalculationStrategy)
+  }
+
+}
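A usage sketch of the new class (illustrative only: DecisionTreeMetadata is private[tree], sc is an assumed SparkContext, and imports are as in DecisionTreeSuite):

    // 1000 examples, 3 classes; feature 0 categorical with arity 4, feature 1 continuous.
    val data = sc.parallelize(0 until 1000).map { i =>
      LabeledPoint((i % 3).toDouble, Vectors.dense((i % 4).toDouble, i.toDouble))
    }
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
      numClassesForClassification = 3, maxBins = 32, categoricalFeaturesInfo = Map(0 -> 4))
    val metadata = DecisionTreeMetadata.buildMetadata(data, strategy)
    // Arity 4: k - 1 = 3 < log2(32 + 1) ~= 5.04, so feature 0 is unordered.
    assert(metadata.isMulticlass && metadata.isUnordered(0) && metadata.isContinuous(1))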
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
index ccac1031fd9d91149638a1840c36d09c545e9d53..170e43e2220833242a3d41b66eba40563efed79e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.mllib.tree.impl
 
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.model.Bin
 import org.apache.spark.rdd.RDD
 
@@ -48,50 +47,35 @@ private[tree] object TreePoint {
    * Convert an input dataset into its TreePoint representation,
    * binning feature values in preparation for DecisionTree training.
    * @param input     Input dataset.
-   * @param strategy  DecisionTree training info, used for dataset metadata.
    * @param bins      Bins for features, of size (numFeatures, numBins).
+   * @param metadata Learning and dataset metadata.
    * @return  TreePoint dataset representation
    */
   def convertToTreeRDD(
       input: RDD[LabeledPoint],
-      strategy: Strategy,
-      bins: Array[Array[Bin]]): RDD[TreePoint] = {
+      bins: Array[Array[Bin]],
+      metadata: DecisionTreeMetadata): RDD[TreePoint] = {
     input.map { x =>
-      TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins,
-        strategy.categoricalFeaturesInfo)
+      TreePoint.labeledPointToTreePoint(x, bins, metadata)
     }
   }
 
   /**
    * Convert one LabeledPoint into its TreePoint representation.
    * @param bins      Bins for features, of size (numFeatures, numBins).
-   * @param categoricalFeaturesInfo  Map over categorical features: feature index --> feature arity
    */
   private def labeledPointToTreePoint(
       labeledPoint: LabeledPoint,
-      isMulticlassClassification: Boolean,
       bins: Array[Array[Bin]],
-      categoricalFeaturesInfo: Map[Int, Int]): TreePoint = {
+      metadata: DecisionTreeMetadata): TreePoint = {
 
     val numFeatures = labeledPoint.features.size
     val numBins = bins(0).size
     val arr = new Array[Int](numFeatures)
     var featureIndex = 0
     while (featureIndex < numFeatures) {
-      val featureInfo = categoricalFeaturesInfo.get(featureIndex)
-      val isFeatureContinuous = featureInfo.isEmpty
-      if (isFeatureContinuous) {
-        arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false,
-          bins, categoricalFeaturesInfo)
-      } else {
-        val featureCategories = featureInfo.get
-        val isSpaceSufficientForAllCategoricalSplits
-          = numBins > math.pow(2, featureCategories.toInt - 1) - 1
-        val isUnorderedFeature =
-          isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
-        arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous,
-          isUnorderedFeature, bins, categoricalFeaturesInfo)
-      }
+      arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex),
+        metadata.isUnordered(featureIndex), bins, metadata.featureArity)
       featureIndex += 1
     }
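A minimal sketch of the binning TreePoint performs for a continuous feature: raw Double values become Int bin indices, so later passes compare integers instead of Doubles. The thresholds and the binIndex helper are made up for illustration, not produced by findSplitsBins:

    // A value v lands in bin i s.t. lowSplit.threshold < v <= highSplit.threshold.
    val thresholds = Array(1.0, 3.0, 7.0)        // numBins - 1 = 3 split thresholds
    def binIndex(value: Double): Int = {
      val i = thresholds.indexWhere(value <= _)
      if (i == -1) thresholds.length else i      // above all thresholds: last bin
    }
    val binnedFeatures = Array(0.5, 3.0, 9.9).map(binIndex)
    assert(binnedFeatures.sameElements(Array(0, 1, 3)))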
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index c89c1e371a40e16008e07ffbc8cc7d11a16ca2d6..af35d88f713e50130df681dbaf9c81b7b0035d72 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 
 /**
- * Used for "binning" the features bins for faster best split calculation. For a continuous
- * feature, a bin is determined by a low and a high "split". For a categorical feature,
- * the a bin is determined using a single label value (category).
+ * Used for "binning" the features bins for faster best split calculation.
+ *
+ * For a continuous feature, the bin is determined by a low and a high split,
+ *  where an example with featureValue falls into the bin s.t.
+ *  lowSplit.threshold < featureValue <= highSplit.threshold.
+ *
+ * For ordered categorical features, there is a 1-1-1 correspondence between
+ *  bins, splits, and feature values.  The bin is determined by category/feature value.
+ *  However, the bins are not necessarily ordered by feature value;
+ *  they are ordered using impurity.
+ * For unordered categorical features, there is a 1-1 correspondence between bins and splits,
+ *  where each bin/split corresponds to a subset of feature values (stored in highSplit.categories).
+ *
  * @param lowSplit signifying the lower threshold for the continuous feature to be
  *                 accepted in the bin
  * @param highSplit signifying the upper threshold for the continuous feature to be
  *                 accepted in the bin
  * @param featureType type of feature -- categorical or continuous
- * @param category categorical label value accepted in the bin for binary classification
+ * @param category categorical label value accepted in the bin for ordered features
  */
 private[tree]
 case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
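The continuous-bin invariant stated above, as a tiny standalone check (fallsInBin is illustrative, not the Bin class itself):

    // An example falls in a bin iff lowSplit.threshold < value <= highSplit.threshold,
    // so a value exactly on a threshold goes to the lower bin.
    def fallsInBin(value: Double, low: Double, high: Double): Boolean =
      low < value && value <= high
    assert(fallsInBin(2.5, low = 1.0, high = 2.5))
    assert(!fallsInBin(2.5, low = 2.5, high = 4.0))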
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 3d3406b5d5f22ab216e1d66bb524d50b456a3364..0594fd0749d21d1d32ad7a0552341044b6931768 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
    * @return Double prediction from the trained model
    */
   def predict(features: Vector): Double = {
-    topNode.predictIfLeaf(features)
+    topNode.predict(features)
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
deleted file mode 100644
index 2deaf4ae8dcaba9363b4d0b1b4168338389b3fd3..0000000000000000000000000000000000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-/**
- * Filter specifying a split and type of comparison to be applied on features
- * @param split split specifying the feature index, type and threshold
- * @param comparison integer specifying <,=,>
- */
-private[tree] case class Filter(split: Split, comparison: Int) {
-  // Comparison -1,0,1 signifies <.=,>
-  override def toString = " split = " + split + "comparison = " + comparison
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 944f11c2c2e4f157e363c60f1a315e285d46ad27..0eee6262781c1d0e6fad7fe5e2a22d91c64c1162 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -69,24 +69,24 @@ class Node (
 
   /**
    * Predict the value at this node (recursing to a leaf if this node is internal).
-   * @param feature feature value
+   * @param features feature vector
    * @return predicted value
    */
-  def predictIfLeaf(feature: Vector) : Double = {
+  def predict(features: Vector): Double = {
     if (isLeaf) {
       predict
     } else{
       if (split.get.featureType == Continuous) {
-        if (feature(split.get.feature) <= split.get.threshold) {
-          leftNode.get.predictIfLeaf(feature)
+        if (features(split.get.feature) <= split.get.threshold) {
+          leftNode.get.predict(features)
         } else {
-          rightNode.get.predictIfLeaf(feature)
+          rightNode.get.predict(features)
         }
       } else {
-        if (split.get.categories.contains(feature(split.get.feature))) {
-          leftNode.get.predictIfLeaf(feature)
+        if (split.get.categories.contains(features(split.get.feature))) {
+          leftNode.get.predict(features)
         } else {
-          rightNode.get.predictIfLeaf(feature)
+          rightNode.get.predict(features)
         }
       }
     }
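The renamed predict recurses until it reaches a leaf, going left when featureValue <= threshold (continuous splits) or when the value is in the split's category set (categorical splits). A standalone mirror of the continuous case, not the Node class itself:

    sealed trait SketchNode
    case class Leaf(prediction: Double) extends SketchNode
    case class Internal(feature: Int, threshold: Double,
        left: SketchNode, right: SketchNode) extends SketchNode
    def predictSketch(n: SketchNode, features: Array[Double]): Double = n match {
      case Leaf(p) => p
      case Internal(f, thr, l, r) =>
        if (features(f) <= thr) predictSketch(l, features) else predictSketch(r, features)
    }
    val stump = Internal(0, 400.0, Leaf(0.0), Leaf(1.0))
    assert(predictSketch(stump, Array(250.0)) == 0.0)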
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index d7ffd386c05eeb749c703d46f2aea6de264ec200..50fb48b40de3d1eba2b21d580afafeee64fd4fd2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
  * :: DeveloperApi ::
  * Split applied to a feature
  * @param feature feature index
- * @param threshold threshold for continuous feature
+ * @param threshold Threshold for continuous feature.
+ *                  Split left if feature <= threshold, else right.
  * @param featureType type of feature -- categorical or continuous
- * @param categories accepted values for categorical variables
+ * @param categories Split left if categorical feature value is in this set, else right.
  */
 @DeveloperApi
 case class Split(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index a5c49a38dc08f6468e6e54b01413e35ee1e6eaf7..2f36fd907772cd373beba3b7c2980b35d8afe4c3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -23,10 +23,10 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
-import org.apache.spark.mllib.tree.impl.TreePoint
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.LocalSparkContext
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -64,7 +64,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, 3, 2, 100)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(bins.length === 2)
     assert(splits(0).length === 99)
@@ -82,7 +83,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(bins.length === 2)
     assert(splits(0).length === 99)
@@ -162,7 +164,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
 
     // Check splits.
 
@@ -279,7 +282,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 100,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
 
     // Expecting 2^2 - 1 = 3 bins/splits
     assert(splits(0)(0).feature === 0)
@@ -373,7 +377,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 100,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
 
     // 2^10 - 1 > 100, so categorical variables will be ordered
 
@@ -428,10 +433,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       maxDepth = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     val split = bestSplits(0)._1
     assert(split.categories.length === 1)
@@ -456,10 +462,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       maxDepth = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     val split = bestSplits(0)._1
     assert(split.categories.length === 1)
@@ -495,7 +502,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, 3, 2, 100)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
     assert(bins.length === 2)
@@ -503,9 +511,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
     assert(bestSplits(0)._2.gain === 0)
@@ -518,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, 3, 2, 100)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
     assert(bins.length === 2)
@@ -526,9 +535,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
     assert(bestSplits(0)._2.gain === 0)
@@ -542,7 +551,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
     assert(bins.length === 2)
@@ -550,9 +560,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
     assert(bestSplits(0)._2.gain === 0)
@@ -566,7 +576,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
     assert(bins.length === 2)
@@ -574,9 +585,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
     assert(bestSplits(0)._2.gain === 0)
@@ -590,7 +601,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
     assert(bins.length === 2)
@@ -598,14 +610,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1)
-    val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1)
-    val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter))
+    // Train a 1-node model
+    val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
+    val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
+    val nodes: Array[Node] = new Array[Node](7)
+    nodes(0) = modelOneNode.topNode
+    nodes(0).leftNode = None
+    nodes(0).rightNode = None
+
     val parentImpurities = Array(0.5, 0.5, 0.5)
 
     // Single group second level tree construction.
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters,
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes,
       splits, bins, 10)
     assert(bestSplits.length === 2)
     assert(bestSplits(0)._2.gain > 0)
@@ -613,8 +630,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
     // level tree construction.
-    val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1,
-      filters, splits, bins, 0)
+    val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1,
+      nodes, splits, bins, 0)
     assert(bestSplitsWithGroups.length === 2)
     assert(bestSplitsWithGroups(0)._2.gain > 0)
     assert(bestSplitsWithGroups(1)._2.gain > 0)
@@ -629,19 +646,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
       assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
     }
-
   }
 
   test("stump with categorical variables for multiclass classification") {
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     assert(strategy.isMulticlassClassification)
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
     val bestSplit = bestSplits(0)._1
@@ -657,11 +674,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
     arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0))
     arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0))
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 2)
 
-    val model = DecisionTree.train(input, strategy)
+    val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 1.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
@@ -688,20 +705,22 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
   test("stump with categorical variables for multiclass classification, with just enough bins") {
     val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
-      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+      numClassesForClassification = 3, maxBins = maxBins,
+      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
     assert(strategy.isMulticlassClassification)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
 
-    val model = DecisionTree.train(input, strategy)
+    val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 1.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
 
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
     val bestSplit = bestSplits(0)._1
@@ -716,18 +735,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
   test("stump with continuous variables for multiclass classification") {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 3)
     assert(strategy.isMulticlassClassification)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
 
-    val model = DecisionTree.train(input, strategy)
+    val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 0.9)
 
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
     val bestSplit = bestSplits(0)._1
@@ -741,18 +761,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
   test("stump with continuous + categorical variables for multiclass classification") {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
     assert(strategy.isMulticlassClassification)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
 
-    val model = DecisionTree.train(input, strategy)
+    val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 0.9)
 
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
     val bestSplit = bestSplits(0)._1
@@ -765,14 +786,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
   test("stump with categorical variables for ordered multiclass classification") {
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
     assert(strategy.isMulticlassClassification)
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
     val bestSplit = bestSplits(0)._1