diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 8c2b044ea73f215f2ab205f0f3f56fb7f245dbaf..58c1322757a43e76ed822075393a1b0c04cc8ba9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -28,6 +28,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.Logging
+import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
 
 /**
@@ -79,7 +80,7 @@ class RowMatrix(
   private[mllib] def multiplyGramianMatrixBy(v: BDV[Double]): BDV[Double] = {
     val n = numCols().toInt
     val vbr = rows.context.broadcast(v)
-    rows.aggregate(BDV.zeros[Double](n))(
+    rows.treeAggregate(BDV.zeros[Double](n))(
       seqOp = (U, r) => {
         val rBrz = r.toBreeze
         val a = rBrz.dot(vbr.value)
@@ -91,9 +92,7 @@ class RowMatrix(
             s"Do not support vector operation from type ${rBrz.getClass.getName}.")
         }
         U
-      },
-      combOp = (U1, U2) => U1 += U2
-    )
+      }, combOp = (U1, U2) => U1 += U2)
   }
 
   /**
@@ -104,13 +103,11 @@ class RowMatrix(
     val nt: Int = n * (n + 1) / 2
 
     // Compute the upper triangular part of the gram matrix.
-    val GU = rows.aggregate(new BDV[Double](new Array[Double](nt)))(
+    val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))(
       seqOp = (U, v) => {
         RowMatrix.dspr(1.0, v, U.data)
         U
-      },
-      combOp = (U1, U2) => U1 += U2
-    )
+      }, combOp = (U1, U2) => U1 += U2)
 
     RowMatrix.triuToFull(n, GU.data)
   }
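
For reference, the packed layout assumed by the `dspr`/`triuToFull` pair above: BLAS `dspr` with `uplo = 'U'` stores the upper triangle column by column, which is why the accumulator needs only nt = n * (n + 1) / 2 doubles. A one-line illustration (the `packedIndex` helper is hypothetical, not part of the patch):

    // Column-major packed upper-triangular indexing used by dspr/triuToFull:
    // entry (i, j) with i <= j lives at offset i + j * (j + 1) / 2.
    def packedIndex(i: Int, j: Int): Int = i + j * (j + 1) / 2
    // For n = 3: (0,0)->0, (0,1)->1, (1,1)->2, (0,2)->3, (1,2)->4, (2,2)->5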
@@ -290,9 +287,10 @@ class RowMatrix(
         s"We need at least $mem bytes of memory.")
     }
 
-    val (m, mean) = rows.aggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
+    val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
       seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze),
-      combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2)
+      combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) =>
+        (s1._1 + s2._1, s1._2 += s2._2)
     )
 
     updateNumRows(m)
@@ -353,10 +351,9 @@ class RowMatrix(
    * Computes column-wise summary statistics.
    */
   def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
-    val summary = rows.aggregate[MultivariateOnlineSummarizer](new MultivariateOnlineSummarizer)(
+    val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
-      (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
-    )
+      (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
     updateNumRows(summary.count)
     summary
   }
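
Why these RowMatrix changes matter: with plain `aggregate`, every partition sends its full accumulator to the driver, which merges them one at a time. For `computeGramianMatrix` that accumulator holds n * (n + 1) / 2 doubles, so at n = 10000 each partial result is about 4 * 10^8 bytes (roughly 400 MB), and with hundreds of partitions the driver becomes the memory and CPU bottleneck; `treeAggregate` merges most partials on the executors first. A minimal sketch of the same pattern, assuming Breeze on the classpath; the `columnSums` helper is illustrative, not part of the patch:

    import breeze.linalg.{DenseVector => BDV}
    import org.apache.spark.mllib.rdd.RDDFunctions._
    import org.apache.spark.rdd.RDD

    // Fold each partition's rows into one dense accumulator, then let
    // treeAggregate merge the per-partition accumulators on the executors.
    def columnSums(rows: RDD[Array[Double]], n: Int): BDV[Double] =
      rows.treeAggregate(BDV.zeros[Double](n))(
        seqOp = (acc, row) => acc += new BDV(row), // in-place add of one row
        combOp = (a, b) => a += b)                 // in-place merge of two partials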
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 9fd760bf78083756f7b91ad200b0bc1e816e8860..356aa949afcf5c530f5dab24d2b16a13409193db 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -25,6 +25,7 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi}
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.rdd.RDDFunctions._
 
 /**
  * Class used to solve an optimization problem using Gradient Descent.
@@ -177,7 +178,7 @@ object GradientDescent extends Logging {
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
       val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
-        .aggregate((BDV.zeros[Double](n), 0.0))(
+        .treeAggregate((BDV.zeros[Double](n), 0.0))(
           seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
             val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
             (grad, loss + l)
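
The gradient-descent change applies the same recipe to a composite accumulator: a (gradient, loss) pair is folded per sampled example and pairs are merged pairwise, so the driver only combines the few partials that survive the executor-side rounds. The identical pattern appears in the LBFGS change below. A hedged sketch of that shape; `computeGradient` is a hypothetical stand-in for `Gradient.compute`, which accumulates the per-example gradient into `grad` and returns the loss:

    import breeze.linalg.{DenseVector => BDV}
    import org.apache.spark.mllib.rdd.RDDFunctions._
    import org.apache.spark.rdd.RDD

    def sumGradients(
        data: RDD[(Double, Array[Double])], // (label, features) pairs
        n: Int,
        computeGradient: (Array[Double], Double, BDV[Double]) => Double): (BDV[Double], Double) =
      data.treeAggregate((BDV.zeros[Double](n), 0.0))(
        seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
          (grad, loss + computeGradient(features, label, grad)) // grad is updated in place
        },
        combOp = (c1, c2) => (c1._1 += c2._1, c1._2 + c2._2))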
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index 179cd4a3f16255f0ea394e171855c83b1b3abcbf..26a2b62e76ed0a1d078594e918e1529655682733 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -26,6 +26,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.rdd.RDDFunctions._
 
 /**
  * :: DeveloperApi ::
@@ -199,7 +200,7 @@ object LBFGS extends Logging {
       val n = weights.length
       val bcWeights = data.context.broadcast(weights)
 
-      val (gradientSum, lossSum) = data.aggregate((BDV.zeros[Double](n), 0.0))(
+      val (gradientSum, lossSum) = data.treeAggregate((BDV.zeros[Double](n), 0.0))(
           seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
             val l = localGradient.compute(
               features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
index 365b5e75d7f7536d4cdea0e134a1987503271646..b5e403bc8c14d43fef2b6808ef2ac19abaaa31ba 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
@@ -20,7 +20,10 @@ package org.apache.spark.mllib.rdd
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 
+import org.apache.spark.HashPartitioner
+import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
 
 /**
  * Machine learning specific RDD functions.
@@ -44,6 +47,69 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
       new SlidingRDD[T](self, windowSize)
     }
   }
+
+  /**
+   * Reduces the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree; must be at least 1 (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#reduce]]
+   */
+  def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    val cleanF = self.context.clean(f)
+    val reducePartition: Iterator[T] => Option[T] = iter => {
+      if (iter.hasNext) {
+        Some(iter.reduceLeft(cleanF))
+      } else {
+        None
+      }
+    }
+    val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it)))
+    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
+      if (c.isDefined && x.isDefined) {
+        Some(cleanF(c.get, x.get))
+      } else if (c.isDefined) {
+        c
+      } else if (x.isDefined) {
+        x
+      } else {
+        None
+      }
+    }
+    RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth)
+      .getOrElse(throw new UnsupportedOperationException("empty collection"))
+  }
+
+  /**
+   * Aggregates the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree; must be at least 1 (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   */
+  def treeAggregate[U: ClassTag](zeroValue: U)(
+      seqOp: (U, T) => U,
+      combOp: (U, U) => U,
+      depth: Int = 2): U = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    if (self.partitions.size == 0) {
+      return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance())
+    }
+    val cleanSeqOp = self.context.clean(seqOp)
+    val cleanCombOp = self.context.clean(combOp)
+    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+    var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it)))
+    var numPartitions = partiallyAggregated.partitions.size
+    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
+    while (numPartitions > scale + numPartitions / scale) {
+      numPartitions /= scale
+      val curNumPartitions = numPartitions
+      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
+        iter.map((i % curNumPartitions, _))
+      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
+    }
+    partiallyAggregated.reduce(cleanCombOp)
+  }
 }
 
 private[mllib]
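
Tracing `treeAggregate`'s level-shrinking loop with concrete numbers makes the stopping rule clear: for depth = 2 and 400 partitions, scale = ceil(400^(1/2)) = 20; since 400 > 20 + 400/20 = 40, one `reduceByKey` round (keys i % 20) shrinks 400 partials to 20, after which 20 <= 20 + 20/20 = 21 and the loop exits, leaving 20 partials for the final `reduce`. A small usage sketch under the implicit conversion defined in this file; the `sc` value is assumed to exist:

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.rdd.RDDFunctions._

    val sc: SparkContext = ??? // assumed: an existing SparkContext
    val rdd = sc.parallelize(1 to 1000000, 400)
    // One executor-side combine level (400 -> 20 partials), then a driver-side reduce.
    val sum = rdd.treeAggregate(0L)((acc, x) => acc + x, _ + _, depth = 2)
    val max = rdd.treeReduce(_ max _, depth = 2)

Note the design choice in `treeReduce`: rather than duplicating this loop, it lifts elements into `Option[T]` so empty partitions get a usable zero value, reuses `treeAggregate`, and throws only if every partition was empty.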
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
index 3f3b10dfff35ebfa288d592e2490727f48889b10..27a19f793242b33f26cc5461ab9b82187ef7ad0d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -46,4 +46,22 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext {
     val expected = data.flatMap(x => x).sliding(3).toList
     assert(sliding.collect().toList === expected)
   }
+
+  test("treeAggregate") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    val seqOp = (c: Long, x: Int) => c + x
+    val combOp = (c1: Long, c2: Long) => c1 + c2
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth)
+      assert(sum === -1000L) // (-1000 until 1000).sum == -1000: the -999..999 terms cancel
+    }
+  }
+
+  test("treeReduce") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeReduce(_ + _, depth)
+      assert(sum === -1000)
+    }
+  }
 }