diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b311d10023894ffa977ab921329991088779dc7a..03eeaa707715bd266961fad8ae95d8a0db7447cf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -532,6 +532,14 @@ object DecisionTree extends Serializable with Logging {
       Some(mutableNodeToFeatures.toMap)
     }
 
+    // array of nodes to train indexed by node index in group
+    val nodes = new Array[Node](numNodes)
+    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+      nodesForTree.foreach { node =>
+        nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
+      }
+    }
+
     // Calculate best splits for all nodes in the group
     timer.start("chooseSplits")
 
@@ -568,7 +576,7 @@ object DecisionTree extends Serializable with Logging {
 
           // find best split for each node
           val (split: Split, stats: InformationGainStats, predict: Predict) =
-            binsToBestSplit(aggStats, splits, featuresForNode)
+            binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
           (nodeIndex, (split, stats, predict))
         }.collectAsMap()
 
@@ -587,17 +595,30 @@ object DecisionTree extends Serializable with Logging {
         // Extract info for this node.  Create children if not leaf.
         val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
         assert(node.id == nodeIndex)
-        node.predict = predict.predict
+        node.predict = predict
         node.isLeaf = isLeaf
         node.stats = Some(stats)
+        node.impurity = stats.impurity
         logDebug("Node = " + node)
 
         if (!isLeaf) {
           node.split = Some(split)
-          node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex)))
-          node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex)))
-          nodeQueue.enqueue((treeIndex, node.leftNode.get))
-          nodeQueue.enqueue((treeIndex, node.rightNode.get))
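+          // a child is a leaf if it reaches maxDepth or has zero impurity (nothing left to split)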
+          val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
+          val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
+          val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+          node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex),
+            stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
+          node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
+            stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+
+          // enqueue left child and right child if they are not leaves
+          if (!leftChildIsLeaf) {
+            nodeQueue.enqueue((treeIndex, node.leftNode.get))
+          }
+          if (!rightChildIsLeaf) {
+            nodeQueue.enqueue((treeIndex, node.rightNode.get))
+          }
+
           logDebug("leftChildIndex = " + node.leftNode.get.id +
             ", impurity = " + stats.leftImpurity)
           logDebug("rightChildIndex = " + node.rightNode.get.id +
@@ -617,7 +638,8 @@ object DecisionTree extends Serializable with Logging {
   private def calculateGainForSplit(
       leftImpurityCalculator: ImpurityCalculator,
       rightImpurityCalculator: ImpurityCalculator,
-      metadata: DecisionTreeMetadata): InformationGainStats = {
+      metadata: DecisionTreeMetadata,
+      impurity: Double): InformationGainStats = {
     val leftCount = leftImpurityCalculator.count
     val rightCount = rightImpurityCalculator.count
 
@@ -630,11 +652,6 @@ object DecisionTree extends Serializable with Logging {
 
     val totalCount = leftCount + rightCount
 
-    val parentNodeAgg = leftImpurityCalculator.copy
-    parentNodeAgg.add(rightImpurityCalculator)
-
-    val impurity = parentNodeAgg.calculate()
-
     val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
     val rightImpurity = rightImpurityCalculator.calculate()
 
@@ -649,7 +666,18 @@ object DecisionTree extends Serializable with Logging {
       return InformationGainStats.invalidInformationGainStats
     }
 
-    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
+    // calculate predictions for the left and right children
+    val leftPredict = calculatePredict(leftImpurityCalculator)
+    val rightPredict = calculatePredict(rightImpurityCalculator)
+
+    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
+      leftPredict, rightPredict)
+  }
+
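+  /**
+   * Compute the predicted value and its probability from the given impurity aggregator.
+   */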
+  private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
+    val predict = impurityCalculator.predict
+    val prob = impurityCalculator.prob(predict)
+    new Predict(predict, prob)
   }
 
   /**
@@ -657,17 +685,17 @@ object DecisionTree extends Serializable with Logging {
    * Note that this function is called only once for each node.
    * @param leftImpurityCalculator left node aggregates for a split
    * @param rightImpurityCalculator right node aggregates for a split
-   * @return predict value for current node
+   * @return predict value and impurity for current node
    */
-  private def calculatePredict(
+  private def calculatePredictImpurity(
       leftImpurityCalculator: ImpurityCalculator,
-      rightImpurityCalculator: ImpurityCalculator): Predict =  {
+      rightImpurityCalculator: ImpurityCalculator): (Predict, Double) =  {
     val parentNodeAgg = leftImpurityCalculator.copy
     parentNodeAgg.add(rightImpurityCalculator)
-    val predict = parentNodeAgg.predict
-    val prob = parentNodeAgg.prob(predict)
+    val predict = calculatePredict(parentNodeAgg)
+    val impurity = parentNodeAgg.calculate()
 
-    new Predict(predict, prob)
+    (predict, impurity)
   }
 
   /**
@@ -678,10 +706,16 @@ object DecisionTree extends Serializable with Logging {
   private def binsToBestSplit(
       binAggregates: DTStatsAggregator,
       splits: Array[Array[Split]],
-      featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
+      featuresForNode: Option[Array[Int]],
+      node: Node): (Split, InformationGainStats, Predict) = {
 
-    // calculate predict only once
-    var predict: Option[Predict] = None
+    // calculate predict and impurity only for the top node (level 0);
+    // other nodes reuse the values computed when their parent was split
+    val level = Node.indexToLevel(node.id)
+    var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
+      None
+    } else {
+      Some((node.predict, node.impurity))
+    }
 
     // For each (feature, split), calculate the gain, and select the best (feature, split).
     val (bestSplit, bestSplitStats) =
@@ -708,9 +742,10 @@ object DecisionTree extends Serializable with Logging {
             val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
             val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
             rightChildStats.subtract(leftChildStats)
-            predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+            predictWithImpurity = Some(predictWithImpurity.getOrElse(
+              calculatePredictImpurity(leftChildStats, rightChildStats)))
             val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata)
+              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
             (splitIdx, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -722,9 +757,10 @@ object DecisionTree extends Serializable with Logging {
           Range(0, numSplits).map { splitIndex =>
             val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
             val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
-            predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+            predictWithImpurity = Some(predictWithImpurity.getOrElse(
+              calculatePredictImpurity(leftChildStats, rightChildStats)))
             val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata)
+              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -794,9 +830,10 @@ object DecisionTree extends Serializable with Logging {
             val rightChildStats =
               binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
             rightChildStats.subtract(leftChildStats)
-            predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+            predictWithImpurity = Some(predictWithImpurity.getOrElse(
+              calculatePredictImpurity(leftChildStats, rightChildStats)))
             val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata)
+              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         val categoriesForSplit =
@@ -807,9 +844,7 @@ object DecisionTree extends Serializable with Logging {
       }
     }.maxBy(_._2.gain)
 
-    assert(predict.isDefined, "must calculate predict for each node")
-
-    (bestSplit, bestSplitStats, predict.get)
+    (bestSplit, bestSplitStats, predictWithImpurity.get._1)
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index a89e71e115806f8f29dc7952a561721fba260b38..9a50ecb550c38500d857466f689d1cf2c8caaeb7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -26,13 +26,17 @@ import org.apache.spark.annotation.DeveloperApi
  * @param impurity current node impurity
  * @param leftImpurity left node impurity
  * @param rightImpurity right node impurity
+ * @param leftPredict left node predict
+ * @param rightPredict right node predict
  */
 @DeveloperApi
 class InformationGainStats(
     val gain: Double,
     val impurity: Double,
     val leftImpurity: Double,
-    val rightImpurity: Double) extends Serializable {
+    val rightImpurity: Double,
+    val leftPredict: Predict,
+    val rightPredict: Predict) extends Serializable {
 
   override def toString = {
     "gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
@@ -58,5 +62,6 @@ private[tree] object InformationGainStats {
    * denote that current split doesn't satisfies minimum info gain or
    * minimum number of instances per node.
    */
-  val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
+  val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
+    new Predict(0.0, 0.0), new Predict(0.0, 0.0))
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 56c3e25d9285fb4bda1ece14eddebdddbd674f45..2179da8dbe03e8949e43a95bcb7a25e46cf21d86 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -32,7 +32,8 @@ import org.apache.spark.mllib.linalg.Vector
  *
  * @param id integer node id, from 1
  * @param predict predicted value at the node
- * @param isLeaf whether the leaf is a node
+ * @param impurity current node impurity
+ * @param isLeaf whether the node is a leaf
  * @param split split to calculate left and right nodes
  * @param leftNode  left child
  * @param rightNode right child
@@ -41,7 +42,8 @@ import org.apache.spark.mllib.linalg.Vector
 @DeveloperApi
 class Node (
     val id: Int,
-    var predict: Double,
+    var predict: Predict,
+    var impurity: Double,
     var isLeaf: Boolean,
     var split: Option[Split],
     var leftNode: Option[Node],
@@ -49,7 +51,7 @@ class Node (
     var stats: Option[InformationGainStats]) extends Serializable with Logging {
 
   override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
-    "split = " + split + ", stats = " + stats
+    "impurity =  " + impurity + "split = " + split + ", stats = " + stats
 
   /**
    * build the left node and right nodes if not leaf
@@ -62,6 +64,7 @@ class Node (
     logDebug("id = " + id + ", split = " + split)
     logDebug("stats = " + stats)
     logDebug("predict = " + predict)
+    logDebug("impurity = " + impurity)
     if (!isLeaf) {
       leftNode = Some(nodes(Node.leftChildIndex(id)))
       rightNode = Some(nodes(Node.rightChildIndex(id)))
@@ -77,7 +80,7 @@ class Node (
    */
   def predict(features: Vector) : Double = {
     if (isLeaf) {
-      predict
+      predict.predict
     } else{
       if (split.get.featureType == Continuous) {
         if (features(split.get.feature) <= split.get.threshold) {
@@ -109,7 +112,7 @@ class Node (
     } else {
       Some(rightNode.get.deepCopy())
     }
-    new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
+    new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
   }
 
   /**
@@ -154,7 +157,7 @@ class Node (
     }
     val prefix: String = " " * indentFactor
     if (isLeaf) {
-      prefix + s"Predict: $predict\n"
+      prefix + s"Predict: ${predict.predict}\n"
     } else {
       prefix + s"If ${splitToString(split.get, left=true)}\n" +
         leftNode.get.subtreeToString(indentFactor + 1) +
@@ -170,7 +173,27 @@ private[tree] object Node {
   /**
    * Return a node with the given node id (but nothing else set).
    */
-  def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None)
+  def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0,
+    false, None, None, None, None)
+
+  /**
+   * Construct a node with nodeIndex, predict, impurity and isLeaf parameters.
+   * This is used in `DecisionTree.findBestSplits` to construct child nodes
+   * after finding the best splits for parent nodes.
+   * The remaining fields are set when the next level is trained.
+   * @param nodeIndex integer node id, from 1
+   * @param predict predicted value at the node
+   * @param impurity current node impurity
+   * @param isLeaf whether the node is a leaf
+   * @return new node instance
+   */
+  def apply(
+      nodeIndex: Int,
+      predict: Predict,
+      impurity: Double,
+      isLeaf: Boolean): Node = {
+    new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None)
+  }
 
   /**
    * Return the index of the left child of this node.
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index a48ed71a1c5fcdcc988f54c0bee11343883ac12f..98a72b0c4d750f7f504ea2dbeeddf5a61dee256d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -253,7 +253,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val stats = rootNode.stats.get
     assert(stats.gain > 0)
-    assert(rootNode.predict === 1)
+    assert(rootNode.predict.predict === 1)
     assert(stats.impurity > 0.2)
   }
 
@@ -282,7 +282,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val stats = rootNode.stats.get
     assert(stats.gain > 0)
-    assert(rootNode.predict === 0.6)
+    assert(rootNode.predict.predict === 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -352,7 +352,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.gain === 0)
     assert(stats.leftImpurity === 0)
     assert(stats.rightImpurity === 0)
-    assert(rootNode.predict === 1)
+    assert(rootNode.predict.predict === 1)
   }
 
   test("Binary classification stump with fixed label 0 for Entropy") {
@@ -377,7 +377,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.gain === 0)
     assert(stats.leftImpurity === 0)
     assert(stats.rightImpurity === 0)
-    assert(rootNode.predict === 0)
+    assert(rootNode.predict.predict === 0)
   }
 
   test("Binary classification stump with fixed label 1 for Entropy") {
@@ -402,7 +402,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.gain === 0)
     assert(stats.leftImpurity === 0)
     assert(stats.rightImpurity === 0)
-    assert(rootNode.predict === 1)
+    assert(rootNode.predict.predict === 1)
   }
 
   test("Second level node building with vs. without groups") {
@@ -471,7 +471,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       assert(stats1.impurity === stats2.impurity)
       assert(stats1.leftImpurity === stats2.leftImpurity)
       assert(stats1.rightImpurity === stats2.rightImpurity)
-      assert(children1(i).predict === children2(i).predict)
+      assert(children1(i).predict.predict === children2(i).predict.predict)
     }
   }
 
@@ -646,7 +646,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val model = DecisionTree.train(rdd, strategy)
     assert(model.topNode.isLeaf)
-    assert(model.topNode.predict == 0.0)
+    assert(model.topNode.predict.predict == 0.0)
     val predicts = rdd.map(p => model.predict(p.features)).collect()
     predicts.foreach { predict =>
       assert(predict == 0.0)
@@ -693,7 +693,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val model = DecisionTree.train(input, strategy)
     assert(model.topNode.isLeaf)
-    assert(model.topNode.predict == 0.0)
+    assert(model.topNode.predict.predict == 0.0)
     val predicts = input.map(p => model.predict(p.features)).collect()
     predicts.foreach { predict =>
       assert(predict == 0.0)
@@ -705,6 +705,92 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val gain = rootNode.stats.get
     assert(gain == InformationGainStats.invalidInformationGainStats)
   }
+
+  test("Avoid aggregation on the last level") {
+    val arr = new Array[LabeledPoint](4)
+    arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
+    arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
+    arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
+    arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
+    val input = sc.parallelize(arr)
+
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+      numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3))
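+    // with maxDepth = 1, the root's children are at the last level, so their predict and
+    // impurity should come from the root's split stats without another aggregation pass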
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+    val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+    val topNode = Node.emptyNode(nodeIndex = 1)
+    assert(topNode.predict.predict === Double.MinValue)
+    assert(topNode.impurity === -1.0)
+    assert(topNode.isLeaf === false)
+
+    val nodesForGroup = Map((0, Array(topNode)))
+    val treeToNodeToIndexInfo = Map((0, Map(
+      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+    )))
+    val nodeQueue = new mutable.Queue[(Int, Node)]()
+    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+    // leaf nodes should not have been enqueued into the node queue
+    assert(nodeQueue.isEmpty)
+
+    // impurity and predict should have been set for topNode
+    assert(topNode.predict.predict !== Double.MinValue)
+    assert(topNode.impurity !== -1.0)
+
+    // impurity and predict should have been set for the child nodes
+    assert(topNode.leftNode.get.predict.predict === 0.0)
+    assert(topNode.rightNode.get.predict.predict === 1.0)
+    assert(topNode.leftNode.get.impurity === 0.0)
+    assert(topNode.rightNode.get.impurity === 0.0)
+  }
+
+  test("Avoid aggregation if impurity is 0.0") {
+    val arr = new Array[LabeledPoint](4)
+    arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
+    arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
+    arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
+    arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
+    val input = sc.parallelize(arr)
+
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3))
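+    // each child of the root is pure (impurity 0.0), so neither child should be enqueued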
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+    val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+    val topNode = Node.emptyNode(nodeIndex = 1)
+    assert(topNode.predict.predict === Double.MinValue)
+    assert(topNode.impurity === -1.0)
+    assert(topNode.isLeaf === false)
+
+    val nodesForGroup = Map((0, Array(topNode)))
+    val treeToNodeToIndexInfo = Map((0, Map(
+      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+    )))
+    val nodeQueue = new mutable.Queue[(Int, Node)]()
+    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+    // nodes with impurity 0.0 should not have been enqueued into the node queue
+    assert(nodeQueue.isEmpty)
+
+    // impurity and predict should have been set for topNode
+    assert(topNode.predict.predict !== Double.MinValue)
+    assert(topNode.impurity !== -1.0)
+
+    // impurity and predict should have been set for the child nodes
+    assert(topNode.leftNode.get.predict.predict === 0.0)
+    assert(topNode.rightNode.get.predict.predict === 1.0)
+    assert(topNode.leftNode.get.impurity === 0.0)
+    assert(topNode.rightNode.get.impurity === 0.0)
+  }
 }
 
 object DecisionTreeSuite {