diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 62bf18d82d9b06e3670a8aed3a335818e962a64d..0f91c942ecd504881ba8fb2998ec165e484dd934 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -348,6 +348,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    */
   def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
 
+  /**
+   * Reduces the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.api.java.JavaRDDLike#reduce]]
+   */
+  def treeReduce(f: JFunction2[T, T, T], depth: Int): T = rdd.treeReduce(f, depth)
+
+  /**
+   * [[org.apache.spark.api.java.JavaRDDLike#treeReduce]] with suggested depth 2.
+   */
+  def treeReduce(f: JFunction2[T, T, T]): T = treeReduce(f, 2)
+
   /**
    * Aggregate the elements of each partition, and then the results for all the partitions, using a
    * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
@@ -369,6 +382,30 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     combOp: JFunction2[U, U, U]): U =
     rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U])
 
+  /**
+   * Aggregates the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.api.java.JavaRDDLike#aggregate]]
+   */
+  def treeAggregate[U](
+      zeroValue: U,
+      seqOp: JFunction2[U, T, U],
+      combOp: JFunction2[U, U, U],
+      depth: Int): U = {
+    rdd.treeAggregate(zeroValue)(seqOp, combOp, depth)(fakeClassTag[U])
+  }
+
+  /**
+   * [[org.apache.spark.api.java.JavaRDDLike#treeAggregate]] with suggested depth 2.
+   */
+  def treeAggregate[U](
+      zeroValue: U,
+      seqOp: JFunction2[U, T, U],
+      combOp: JFunction2[U, U, U]): U = {
+    treeAggregate(zeroValue, seqOp, combOp, 2)
+  }
+
   /**
    * Return the number of elements in the RDD.
    */
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index ab7410a1f7f99c548e9b7737ef980325194b606d..5f39384975f9b0f96360359e4168ca16bddab232 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -900,6 +900,38 @@ abstract class RDD[T: ClassTag](
     jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
   }
 
+  /**
+   * Reduces the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#reduce]]
+   */
+  def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    val cleanF = context.clean(f)
+    val reducePartition: Iterator[T] => Option[T] = iter => {
+      if (iter.hasNext) {
+        Some(iter.reduceLeft(cleanF))
+      } else {
+        None
+      }
+    }
+    val partiallyReduced = mapPartitions(it => Iterator(reducePartition(it)))
+    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
+      if (c.isDefined && x.isDefined) {
+        Some(cleanF(c.get, x.get))
+      } else if (c.isDefined) {
+        c
+      } else if (x.isDefined) {
+        x
+      } else {
+        None
+      }
+    }
+    partiallyReduced.treeAggregate(Option.empty[T])(op, op, depth)
+      .getOrElse(throw new UnsupportedOperationException("empty collection"))
+  }
+
   /**
    * Aggregate the elements of each partition, and then the results for all the partitions, using a
    * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
@@ -935,6 +967,37 @@ abstract class RDD[T: ClassTag](
     jobResult
   }
 
+  /**
+   * Aggregates the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   */
+  def treeAggregate[U: ClassTag](zeroValue: U)(
+      seqOp: (U, T) => U,
+      combOp: (U, U) => U,
+      depth: Int = 2): U = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    if (partitions.size == 0) {
+      return Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
+    }
+    val cleanSeqOp = context.clean(seqOp)
+    val cleanCombOp = context.clean(combOp)
+    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+    var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
+    var numPartitions = partiallyAggregated.partitions.size
+    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
+    while (numPartitions > scale + numPartitions / scale) {
+      numPartitions /= scale
+      val curNumPartitions = numPartitions
+      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
+        iter.map((i % curNumPartitions, _))
+      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
+    }
+    partiallyAggregated.reduce(cleanCombOp)
+  }
+
   /**
    * Return the number of elements in the RDD.
    */
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 004de05c10ee18452f5b93674ed2e748dea0f215..b16a1e9460286ed4e39e4210cab7df0edde2724d 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -492,6 +492,36 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(33, sum);
   }
 
+  @Test
+  public void treeReduce() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10);
+    Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() {
+      @Override
+      public Integer call(Integer a, Integer b) {
+        return a + b;
+      }
+    };
+    for (int depth = 1; depth <= 10; depth++) {
+      int sum = rdd.treeReduce(add, depth);
+      Assert.assertEquals(-5, sum);
+    }
+  }
+
+  @Test
+  public void treeAggregate() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10);
+    Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() {
+      @Override
+      public Integer call(Integer a, Integer b) {
+        return a + b;
+      }
+    };
+    for (int depth = 1; depth <= 10; depth++) {
+      int sum = rdd.treeAggregate(0, add, add, depth);
+      Assert.assertEquals(-5, sum);
+    }
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void aggregateByKey() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index e33b4bbbb8e4ca167b86e90ed690f7920adcdbc1..bede1ffb3e2d06210c21439e2beb442b6387f254 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -157,6 +157,24 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
   }
 
+  test("treeAggregate") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    def seqOp = (c: Long, x: Int) => c + x
+    def combOp = (c1: Long, c2: Long) => c1 + c2
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth)
+      assert(sum === -1000L)
+    }
+  }
+
+  test("treeReduce") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeReduce(_ + _, depth)
+      assert(sum === -1000)
+    }
+  }
+
   test("basic caching") {
     val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
     assert(rdd.collect().toList === List(1, 2, 3, 4))
@@ -967,4 +985,5 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assertFails { sc.parallelize(1 to 100) }
     assertFails { sc.textFile("/nonexistent-path") }
   }
+
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
index 3260f27513c7f78dee5c482fcf7038632cb9ebec..a89eea0e21be2ad0298b1ae23deaafbbf38acad4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -22,7 +22,6 @@ import breeze.linalg.{DenseVector => BDV}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd.RDD
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index 3c2091732f9b06698db490ce5d2ac0b2d5810007..2f2c6f94e909522e5e32676fffb95257b302d495 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.feature
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 02075edbabf859d04976742d0f8468dece606aab..ddca30c3c01c81b057f461eabe639033b243bc73 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -30,7 +30,6 @@ import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg._
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 0857877951c82dc557e8990f6f3f19445e04ac4f..4b7d0589c973bfb6774aa40e8bf0c3ed10181a8b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -25,7 +25,6 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi}
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
-import org.apache.spark.mllib.rdd.RDDFunctions._
 
 /**
  * Class used to solve an optimization problem using Gradient Descent.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index d16d0daf085650d4218825f37043555585b43c84..d5e4f4ccbff10d108d7a9337ea1bf253f4606046 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -26,7 +26,6 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.axpy
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd.RDD
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
index 57c0768084e41f2545b2118bb40069bead1b2ce1..78172843be56e7fac61505cbe53018b8d4210abc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
@@ -21,10 +21,7 @@ import scala.language.implicitConversions
 import scala.reflect.ClassTag
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.HashPartitioner
-import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.Utils
 
 /**
  * Machine learning specific RDD functions.
@@ -53,63 +50,25 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable {
    * Reduces the elements of this RDD in a multi-level tree pattern.
    *
    * @param depth suggested depth of the tree (default: 2)
-   * @see [[org.apache.spark.rdd.RDD#reduce]]
+   * @see [[org.apache.spark.rdd.RDD#treeReduce]]
+   * @deprecated Use [[org.apache.spark.rdd.RDD#treeReduce]] instead.
    */
-  def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
-    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
-    val cleanF = self.context.clean(f)
-    val reducePartition: Iterator[T] => Option[T] = iter => {
-      if (iter.hasNext) {
-        Some(iter.reduceLeft(cleanF))
-      } else {
-        None
-      }
-    }
-    val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it)))
-    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
-      if (c.isDefined && x.isDefined) {
-        Some(cleanF(c.get, x.get))
-      } else if (c.isDefined) {
-        c
-      } else if (x.isDefined) {
-        x
-      } else {
-        None
-      }
-    }
-    RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth)
-      .getOrElse(throw new UnsupportedOperationException("empty collection"))
-  }
+  @deprecated("Use RDD.treeReduce instead.", "1.3.0")
+  def treeReduce(f: (T, T) => T, depth: Int = 2): T = self.treeReduce(f, depth)
 
   /**
    * Aggregates the elements of this RDD in a multi-level tree pattern.
    *
    * @param depth suggested depth of the tree (default: 2)
-   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   * @see [[org.apache.spark.rdd.RDD#treeAggregate]]
+   * @deprecated Use [[org.apache.spark.rdd.RDD#treeAggregate]] instead.
    */
+  @deprecated("Use RDD.treeAggregate instead.", "1.3.0")
   def treeAggregate[U: ClassTag](zeroValue: U)(
       seqOp: (U, T) => U,
       combOp: (U, U) => U,
       depth: Int = 2): U = {
-    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
-    if (self.partitions.size == 0) {
-      return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance())
-    }
-    val cleanSeqOp = self.context.clean(seqOp)
-    val cleanCombOp = self.context.clean(combOp)
-    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
-    var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it)))
-    var numPartitions = partiallyAggregated.partitions.size
-    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
-    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
-    while (numPartitions > scale + numPartitions / scale) {
-      numPartitions /= scale
-      val curNumPartitions = numPartitions
-      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
-        iter.map((i % curNumPartitions, _))
-      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
-    }
-    partiallyAggregated.reduce(cleanCombOp)
+    self.treeAggregate(zeroValue)(seqOp, combOp, depth)
   }
 }
 
@@ -117,5 +76,5 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable {
 object RDDFunctions {
 
   /** Implicit conversion from an RDD to RDDFunctions. */
-  implicit def fromRDD[T: ClassTag](rdd: RDD[T]) = new RDDFunctions[T](rdd)
+  implicit def fromRDD[T: ClassTag](rdd: RDD[T]): RDDFunctions[T] = new RDDFunctions[T](rdd)
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
index 4c93c0ca4f86ce7b0beb8f2ce76d1171a844125c..e9e510b6f55462f3206688ba99e1151e5670857f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
@@ -22,7 +22,6 @@ import org.scalatest.FunSuite
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
 import org.apache.spark.rdd.RDD
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
index 681ce9263933bf751e6428cea7f61ce180af58ad..6d6c0aa5be81291ab52595c2cd3bee687746c600 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -46,22 +46,4 @@ class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
     val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq)
     assert(sliding === expected)
   }
-
-  test("treeAggregate") {
-    val rdd = sc.makeRDD(-1000 until 1000, 10)
-    def seqOp = (c: Long, x: Int) => c + x
-    def combOp = (c1: Long, c2: Long) => c1 + c2
-    for (depth <- 1 until 10) {
-      val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth)
-      assert(sum === -1000L)
-    }
-  }
-
-  test("treeReduce") {
-    val rdd = sc.makeRDD(-1000 until 1000, 10)
-    for (depth <- 1 until 10) {
-      val sum = rdd.treeReduce(_ + _, depth)
-      assert(sum === -1000)
-    }
-  }
 }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index e750fed7448cd6d5aa0a05619a463319fc322779..14ba03ed4634b76521b5a338f144929d894264cd 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -113,6 +113,12 @@ object MimaExcludes {
             // SPARK-5270
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.api.java.JavaRDDLike.isEmpty")
+          ) ++ Seq(
+            // SPARK-5430
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.api.java.JavaRDDLike.treeReduce"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.api.java.JavaRDDLike.treeAggregate")
           ) ++ Seq(
             // SPARK-5297 Java FileStream do not work with custom key/values
             ProblemFilters.exclude[MissingMethodProblem](
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index b6dd5a3bf028dae8dbefc649575aec72fb695b34..2f8a0edfe964409d23604db42442f118805f6a20 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -29,7 +29,7 @@ import warnings
 import heapq
 import bisect
 import random
-from math import sqrt, log, isinf, isnan
+from math import sqrt, log, isinf, isnan, pow, ceil
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
@@ -726,6 +726,43 @@ class RDD(object):
             return reduce(f, vals)
         raise ValueError("Can not reduce() empty RDD")
 
+    def treeReduce(self, f, depth=2):
+        """
+        Reduces the elements of this RDD in a multi-level tree pattern.
+
+        :param depth: suggested depth of the tree (default: 2)
+
+        >>> add = lambda x, y: x + y
+        >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+        >>> rdd.treeReduce(add)
+        -5
+        >>> rdd.treeReduce(add, 1)
+        -5
+        >>> rdd.treeReduce(add, 2)
+        -5
+        >>> rdd.treeReduce(add, 5)
+        -5
+        >>> rdd.treeReduce(add, 10)
+        -5
+        """
+        if depth < 1:
+            raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+        zeroValue = None, True  # Use the second entry to indicate whether this is a dummy value.
+
+        def op(x, y):
+            if x[1]:
+                return y
+            elif y[1]:
+                return x
+            else:
+                return f(x[0], y[0]), False
+
+        reduced = self.map(lambda x: (x, False)).treeAggregate(zeroValue, op, op, depth)
+        if reduced[1]:
+            raise ValueError("Cannot reduce empty RDD.")
+        return reduced[0]
+
     def fold(self, zeroValue, op):
         """
         Aggregate the elements of each partition, and then the results for all
@@ -777,6 +814,58 @@ class RDD(object):
 
         return self.mapPartitions(func).fold(zeroValue, combOp)
 
+    def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
+        """
+        Aggregates the elements of this RDD in a multi-level tree
+        pattern.
+
+        :param depth: suggested depth of the tree (default: 2)
+
+        >>> add = lambda x, y: x + y
+        >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+        >>> rdd.treeAggregate(0, add, add)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 1)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 2)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 5)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 10)
+        -5
+        """
+        if depth < 1:
+            raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+        if self.getNumPartitions() == 0:
+            return zeroValue
+
+        def aggregatePartition(iterator):
+            acc = zeroValue
+            for obj in iterator:
+                acc = seqOp(acc, obj)
+            yield acc
+
+        partiallyAggregated = self.mapPartitions(aggregatePartition)
+        numPartitions = partiallyAggregated.getNumPartitions()
+        scale = max(int(ceil(pow(numPartitions, 1.0 / depth))), 2)
+        # If creating an extra level doesn't help reduce the wall-clock time, we stop the tree
+        # aggregation.
+        while numPartitions > scale + numPartitions / scale:
+            numPartitions /= scale
+            curNumPartitions = numPartitions
+
+            def mapPartition(i, iterator):
+                for obj in iterator:
+                    yield (i % curNumPartitions, obj)
+
+            partiallyAggregated = partiallyAggregated \
+                .mapPartitionsWithIndex(mapPartition) \
+                .reduceByKey(combOp, curNumPartitions) \
+                .values()
+
+        return partiallyAggregated.reduce(combOp)
+
     def max(self, key=None):
         """
         Find the maximum item in this RDD.