diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 62bf18d82d9b06e3670a8aed3a335818e962a64d..0f91c942ecd504881ba8fb2998ec165e484dd934 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -348,6 +348,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    */
   def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
 
+  /**
+   * Reduces the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.api.java.JavaRDDLike#reduce]]
+   */
+  def treeReduce(f: JFunction2[T, T, T], depth: Int): T = rdd.treeReduce(f, depth)
+
+  /**
+   * [[org.apache.spark.api.java.JavaRDDLike#treeReduce]] with suggested depth 2.
+   */
+  def treeReduce(f: JFunction2[T, T, T]): T = treeReduce(f, 2)
+
   /**
    * Aggregate the elements of each partition, and then the results for all the partitions, using a
    * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
@@ -369,6 +382,30 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
       combOp: JFunction2[U, U, U]): U =
     rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U])
 
+  /**
+   * Aggregates the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.api.java.JavaRDDLike#aggregate]]
+   */
+  def treeAggregate[U](
+      zeroValue: U,
+      seqOp: JFunction2[U, T, U],
+      combOp: JFunction2[U, U, U],
+      depth: Int): U = {
+    rdd.treeAggregate(zeroValue)(seqOp, combOp, depth)(fakeClassTag[U])
+  }
+
+  /**
+   * [[org.apache.spark.api.java.JavaRDDLike#treeAggregate]] with suggested depth 2.
+   */
+  def treeAggregate[U](
+      zeroValue: U,
+      seqOp: JFunction2[U, T, U],
+      combOp: JFunction2[U, U, U]): U = {
+    treeAggregate(zeroValue, seqOp, combOp, 2)
+  }
+
   /**
    * Return the number of elements in the RDD.
    */
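A note on the Java API shape above: Java has no default parameter values, so the Scala `depth: Int = 2` default becomes a pair of overloads on `JavaRDDLike`. A minimal driver-side sketch of the two call shapes through the Scala API (assuming Spark 1.3+, where this patch lands; the object name is illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object TreeReduceOverloads {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[4]").setAppName("tree-reduce-sketch"))
    val rdd = sc.parallelize(1 to 1000, numSlices = 100)
    // Explicit depth, mirroring the Java overload treeReduce(f, depth).
    val withDepth = rdd.treeReduce(_ + _, depth = 3)
    // Depth omitted, mirroring the Java overload treeReduce(f), which forwards depth = 2.
    val defaultDepth = rdd.treeReduce(_ + _)
    assert(withDepth == 500500 && defaultDepth == 500500)
    sc.stop()
  }
}
```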
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index ab7410a1f7f99c548e9b7737ef980325194b606d..5f39384975f9b0f96360359e4168ca16bddab232 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -900,6 +900,38 @@ abstract class RDD[T: ClassTag](
     jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
   }
 
+  /**
+   * Reduces the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#reduce]]
+   */
+  def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    val cleanF = context.clean(f)
+    val reducePartition: Iterator[T] => Option[T] = iter => {
+      if (iter.hasNext) {
+        Some(iter.reduceLeft(cleanF))
+      } else {
+        None
+      }
+    }
+    val partiallyReduced = mapPartitions(it => Iterator(reducePartition(it)))
+    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
+      if (c.isDefined && x.isDefined) {
+        Some(cleanF(c.get, x.get))
+      } else if (c.isDefined) {
+        c
+      } else if (x.isDefined) {
+        x
+      } else {
+        None
+      }
+    }
+    partiallyReduced.treeAggregate(Option.empty[T])(op, op, depth)
+      .getOrElse(throw new UnsupportedOperationException("empty collection"))
+  }
+
   /**
    * Aggregate the elements of each partition, and then the results for all the partitions, using a
    * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
@@ -935,6 +967,37 @@ abstract class RDD[T: ClassTag](
     jobResult
   }
 
+  /**
+   * Aggregates the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   */
+  def treeAggregate[U: ClassTag](zeroValue: U)(
+      seqOp: (U, T) => U,
+      combOp: (U, U) => U,
+      depth: Int = 2): U = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    if (partitions.size == 0) {
+      return Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
+    }
+    val cleanSeqOp = context.clean(seqOp)
+    val cleanCombOp = context.clean(combOp)
+    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+    var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
+    var numPartitions = partiallyAggregated.partitions.size
+    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
+    while (numPartitions > scale + numPartitions / scale) {
+      numPartitions /= scale
+      val curNumPartitions = numPartitions
+      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
+        iter.map((i % curNumPartitions, _))
+      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
+    }
+    partiallyAggregated.reduce(cleanCombOp)
+  }
+
   /**
    * Return the number of elements in the RDD.
   */
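The core of the new `RDD.treeAggregate` is the `scale`/`while` arithmetic: `scale` is the tree's fan-in, roughly the depth-th root of the partition count, and each pass of the loop shuffles the partial aggregates into `numPartitions / scale` partitions until another level would no longer shrink the work. A standalone sketch (plain Scala, no cluster needed) that reproduces just that arithmetic; `levels` is a name introduced here for illustration:

```scala
import scala.collection.mutable.ArrayBuffer

// Reproduces the partition-count progression of treeAggregate's while loop.
def levels(numPartitionsStart: Int, depth: Int): Seq[Int] = {
  require(depth >= 1)
  val scale = math.max(math.ceil(math.pow(numPartitionsStart, 1.0 / depth)).toInt, 2)
  var numPartitions = numPartitionsStart
  val out = ArrayBuffer(numPartitions)
  // Same stopping rule as the patch: an extra level must actually shrink the work.
  while (numPartitions > scale + numPartitions / scale) {
    numPartitions /= scale
    out += numPartitions
  }
  out.toSeq
}

println(levels(100, 2))   // 100 -> 10; scale = 10, and 10 <= 10 + 10/10, so stop
println(levels(1000, 3))  // 1000 -> 100 -> 10; scale = 10
println(levels(100, 1))   // 100 only; scale = 100, loop never runs: plain aggregate
```

After the loop, a final `reduce(cleanCombOp)` pulls the remaining partials to the driver, so depth 1 degenerates to the classic single-level `aggregate`.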
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 004de05c10ee18452f5b93674ed2e748dea0f215..b16a1e9460286ed4e39e4210cab7df0edde2724d 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -492,6 +492,36 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(33, sum);
   }
 
+  @Test
+  public void treeReduce() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10);
+    Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() {
+      @Override
+      public Integer call(Integer a, Integer b) {
+        return a + b;
+      }
+    };
+    for (int depth = 1; depth <= 10; depth++) {
+      int sum = rdd.treeReduce(add, depth);
+      Assert.assertEquals(-5, sum);
+    }
+  }
+
+  @Test
+  public void treeAggregate() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10);
+    Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() {
+      @Override
+      public Integer call(Integer a, Integer b) {
+        return a + b;
+      }
+    };
+    for (int depth = 1; depth <= 10; depth++) {
+      int sum = rdd.treeAggregate(0, add, add, depth);
+      Assert.assertEquals(-5, sum);
+    }
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void aggregateByKey() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index e33b4bbbb8e4ca167b86e90ed690f7920adcdbc1..bede1ffb3e2d06210c21439e2beb442b6387f254 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -157,6 +157,24 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
   }
 
+  test("treeAggregate") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    def seqOp = (c: Long, x: Int) => c + x
+    def combOp = (c1: Long, c2: Long) => c1 + c2
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth)
+      assert(sum === -1000L)
+    }
+  }
+
+  test("treeReduce") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeReduce(_ + _, depth)
+      assert(sum === -1000)
+    }
+  }
+
   test("basic caching") {
     val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
     assert(rdd.collect().toList === List(1, 2, 3, 4))
@@ -967,4 +985,5 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assertFails { sc.parallelize(1 to 100) }
     assertFails { sc.textFile("/nonexistent-path") }
   }
+
 }
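The expected values in these tests look arbitrary at first glance, but they fall straight out of the fixtures: each positive element cancels its negative counterpart. A quick check in plain Scala:

```scala
// -1000 until 1000: every k in 1..999 cancels -k, leaving only -1000.
assert((-1000 until 1000).sum == -1000)
// Same cancellation in the Java suite's fixture: only -5 survives.
assert(Seq(-5, -4, -3, -2, -1, 1, 2, 3, 4).sum == -5)
```

Looping `depth` from 1 to 10 over 10 partitions exercises both the degenerate single-level case and depths larger than the tree can usefully be.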
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
index 3260f27513c7f78dee5c482fcf7038632cb9ebec..a89eea0e21be2ad0298b1ae23deaafbbf38acad4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -22,7 +22,6 @@ import breeze.linalg.{DenseVector => BDV}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd.RDD
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index 3c2091732f9b06698db490ce5d2ac0b2d5810007..2f2c6f94e909522e5e32676fffb95257b302d495 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.feature
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 02075edbabf859d04976742d0f8468dece606aab..ddca30c3c01c81b057f461eabe639033b243bc73 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -30,7 +30,6 @@ import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg._
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 0857877951c82dc557e8990f6f3f19445e04ac4f..4b7d0589c973bfb6774aa40e8bf0c3ed10181a8b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -25,7 +25,6 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi}
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
-import org.apache.spark.mllib.rdd.RDDFunctions._
 
 /**
  * Class used to solve an optimization problem using Gradient Descent.
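These import deletions (and the two that follow) are all the same cleanup: each file only imported `RDDFunctions._` for the implicit conversion that bolted `treeReduce`/`treeAggregate` onto `RDD`. With the methods now defined on `RDD` itself, the enrichment import is dead code. A toy sketch of the pattern being retired (names here are illustrative, not from the patch):

```scala
import scala.language.implicitConversions

// Before: extra methods arrive via an implicit wrapper, activated by an import.
class SeqOps[T](val self: Seq[T]) {
  def pairSum(f: (T, T) => T): T = self.reduceLeft(f) // stand-in body
}
object SeqOps {
  implicit def fromSeq[T](s: Seq[T]): SeqOps[T] = new SeqOps(s)
}

// With `import SeqOps._` in scope, Seq(1, 2, 3).pairSum(_ + _) compiles;
// once pairSum moves onto the type itself, the import (and wrapper) can go.
```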
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index d16d0daf085650d4218825f37043555585b43c84..d5e4f4ccbff10d108d7a9337ea1bf253f4606046 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -26,7 +26,6 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.axpy
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd.RDD
 
 /**
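The next file is the interesting half of the mllib cleanup: rather than deleting `RDDFunctions.treeReduce`/`treeAggregate` outright, the patch keeps them as deprecated one-line forwarders so existing user code keeps compiling with a warning. The deprecate-and-forward pattern in miniature (a standalone sketch with illustrative names):

```scala
object NewHome {
  def compute(x: Int): Int = x * 2
}

object OldHome {
  // Keep the old entry point alive, but steer callers to the new one.
  @deprecated("Use NewHome.compute instead.", "1.3.0")
  def compute(x: Int): Int = NewHome.compute(x)
}
```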
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
index 57c0768084e41f2545b2118bb40069bead1b2ce1..78172843be56e7fac61505cbe53018b8d4210abc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
@@ -21,10 +21,7 @@ import scala.language.implicitConversions
 import scala.reflect.ClassTag
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.HashPartitioner
-import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.Utils
 
 /**
  * Machine learning specific RDD functions.
@@ -53,63 +50,25 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable {
   /**
    * Reduces the elements of this RDD in a multi-level tree pattern.
    *
    * @param depth suggested depth of the tree (default: 2)
-   * @see [[org.apache.spark.rdd.RDD#reduce]]
+   * @see [[org.apache.spark.rdd.RDD#treeReduce]]
+   * @deprecated Use [[org.apache.spark.rdd.RDD#treeReduce]] instead.
    */
-  def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
-    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
-    val cleanF = self.context.clean(f)
-    val reducePartition: Iterator[T] => Option[T] = iter => {
-      if (iter.hasNext) {
-        Some(iter.reduceLeft(cleanF))
-      } else {
-        None
-      }
-    }
-    val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it)))
-    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
-      if (c.isDefined && x.isDefined) {
-        Some(cleanF(c.get, x.get))
-      } else if (c.isDefined) {
-        c
-      } else if (x.isDefined) {
-        x
-      } else {
-        None
-      }
-    }
-    RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth)
-      .getOrElse(throw new UnsupportedOperationException("empty collection"))
-  }
+  @deprecated("Use RDD.treeReduce instead.", "1.3.0")
+  def treeReduce(f: (T, T) => T, depth: Int = 2): T = self.treeReduce(f, depth)
 
   /**
    * Aggregates the elements of this RDD in a multi-level tree pattern.
    *
    * @param depth suggested depth of the tree (default: 2)
-   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   * @see [[org.apache.spark.rdd.RDD#treeAggregate]]
+   * @deprecated Use [[org.apache.spark.rdd.RDD#treeAggregate]] instead.
    */
+  @deprecated("Use RDD.treeAggregate instead.", "1.3.0")
   def treeAggregate[U: ClassTag](zeroValue: U)(
       seqOp: (U, T) => U,
       combOp: (U, U) => U,
       depth: Int = 2): U = {
-    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
-    if (self.partitions.size == 0) {
-      return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance())
-    }
-    val cleanSeqOp = self.context.clean(seqOp)
-    val cleanCombOp = self.context.clean(combOp)
-    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
-    var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it)))
-    var numPartitions = partiallyAggregated.partitions.size
-    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
-    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
-    while (numPartitions > scale + numPartitions / scale) {
-      numPartitions /= scale
-      val curNumPartitions = numPartitions
-      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
-        iter.map((i % curNumPartitions, _))
-      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
-    }
-    partiallyAggregated.reduce(cleanCombOp)
+    self.treeAggregate(zeroValue)(seqOp, combOp, depth)
   }
 }
@@ -117,5 +76,5 @@ object RDDFunctions {
 
   /** Implicit conversion from an RDD to RDDFunctions. */
-  implicit def fromRDD[T: ClassTag](rdd: RDD[T]) = new RDDFunctions[T](rdd)
+  implicit def fromRDD[T: ClassTag](rdd: RDD[T]): RDDFunctions[T] = new RDDFunctions[T](rdd)
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
index 4c93c0ca4f86ce7b0beb8f2ce76d1171a844125c..e9e510b6f55462f3206688ba99e1151e5670857f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
@@ -22,7 +22,6 @@ import org.scalatest.FunSuite
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
 import org.apache.spark.rdd.RDD
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
index 681ce9263933bf751e6428cea7f61ce180af58ad..6d6c0aa5be81291ab52595c2cd3bee687746c600 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -46,22 +46,4 @@ class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
     val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq)
     assert(sliding === expected)
   }
-
-  test("treeAggregate") {
-    val rdd = sc.makeRDD(-1000 until 1000, 10)
-    def seqOp = (c: Long, x: Int) => c + x
-    def combOp = (c1: Long, c2: Long) => c1 + c2
-    for (depth <- 1 until 10) {
-      val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth)
-      assert(sum === -1000L)
-    }
-  }
-
-  test("treeReduce") {
-    val rdd = sc.makeRDD(-1000 until 1000, 10)
-    for (depth <- 1 until 10) {
-      val sum = rdd.treeReduce(_ + _, depth)
-      assert(sum === -1000)
-    }
-  }
 }
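One small change above that is easy to miss: `fromRDD` gains an explicit return type. For a public implicit conversion that matters, because an inferred result type is part of the API and can shift silently when the body changes. A hedged sketch of the distinction, with illustrative names:

```scala
import scala.language.implicitConversions
import scala.reflect.ClassTag

class Wrapper[T: ClassTag](val self: Seq[T])

object Wrapper {
  // Inferred result type: fine until the body changes and the signature drifts with it.
  // implicit def fromSeq[T: ClassTag](s: Seq[T]) = new Wrapper[T](s)

  // Explicit result type: the public signature is pinned down.
  implicit def fromSeq[T: ClassTag](s: Seq[T]): Wrapper[T] = new Wrapper[T](s)
}
```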
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index e750fed7448cd6d5aa0a05619a463319fc322779..14ba03ed4634b76521b5a338f144929d894264cd 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -113,6 +113,12 @@ object MimaExcludes {
             // SPARK-5270
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.api.java.JavaRDDLike.isEmpty")
+          ) ++ Seq(
+            // SPARK-5430
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.api.java.JavaRDDLike.treeReduce"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.api.java.JavaRDDLike.treeAggregate")
           ) ++ Seq(
             // SPARK-5297 Java FileStream do not work with custom key/values
             ProblemFilters.exclude[MissingMethodProblem](
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index b6dd5a3bf028dae8dbefc649575aec72fb695b34..2f8a0edfe964409d23604db42442f118805f6a20 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -29,7 +29,7 @@ import warnings
 import heapq
 import bisect
 import random
-from math import sqrt, log, isinf, isnan
+from math import sqrt, log, isinf, isnan, pow, ceil
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
@@ -726,6 +726,43 @@ class RDD(object):
             return reduce(f, vals)
         raise ValueError("Can not reduce() empty RDD")
 
+    def treeReduce(self, f, depth=2):
+        """
+        Reduces the elements of this RDD in a multi-level tree pattern.
+
+        :param depth: suggested depth of the tree (default: 2)
+
+        >>> add = lambda x, y: x + y
+        >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+        >>> rdd.treeReduce(add)
+        -5
+        >>> rdd.treeReduce(add, 1)
+        -5
+        >>> rdd.treeReduce(add, 2)
+        -5
+        >>> rdd.treeReduce(add, 5)
+        -5
+        >>> rdd.treeReduce(add, 10)
+        -5
+        """
+        if depth < 1:
+            raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+        zeroValue = None, True  # Use the second entry to indicate whether this is a dummy value.
+
+        def op(x, y):
+            if x[1]:
+                return y
+            elif y[1]:
+                return x
+            else:
+                return f(x[0], y[0]), False
+
+        reduced = self.map(lambda x: (x, False)).treeAggregate(zeroValue, op, op, depth)
+        if reduced[1]:
+            raise ValueError("Cannot reduce empty RDD.")
+        return reduced[0]
+
     def fold(self, zeroValue, op):
         """
         Aggregate the elements of each partition, and then the results for all
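Python has no `Option`, so the `treeReduce` above simulates the Scala version's `Option.empty[T]` zero with a `(value, is_dummy)` pair: `(None, True)` is the neutral element, and `op` drops whichever side is the dummy instead of feeding `None` to the user's function. The same combiner rendered in Scala (a sketch; `sentinelOp` is a name introduced here):

```scala
// (value, isDummy): the Boolean marks the zero element, which must be
// discarded rather than passed to the user's function f.
def sentinelOp[T](f: (T, T) => T): ((T, Boolean), (T, Boolean)) => (T, Boolean) =
  (x, y) =>
    if (x._2) y                    // x is the dummy: keep y as-is
    else if (y._2) x               // y is the dummy: keep x as-is
    else (f(x._1, y._1), false)    // both real: combine

// If the final result is still flagged as a dummy, the RDD was empty.
```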
@@ -777,6 +814,58 @@ class RDD(object):
 
         return self.mapPartitions(func).fold(zeroValue, combOp)
 
+    def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
+        """
+        Aggregates the elements of this RDD in a multi-level tree
+        pattern.
+
+        :param depth: suggested depth of the tree (default: 2)
+
+        >>> add = lambda x, y: x + y
+        >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+        >>> rdd.treeAggregate(0, add, add)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 1)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 2)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 5)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 10)
+        -5
+        """
+        if depth < 1:
+            raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+        if self.getNumPartitions() == 0:
+            return zeroValue
+
+        def aggregatePartition(iterator):
+            acc = zeroValue
+            for obj in iterator:
+                acc = seqOp(acc, obj)
+            yield acc
+
+        partiallyAggregated = self.mapPartitions(aggregatePartition)
+        numPartitions = partiallyAggregated.getNumPartitions()
+        scale = max(int(ceil(pow(numPartitions, 1.0 / depth))), 2)
+        # If creating an extra level doesn't help reduce the wall-clock time, we stop the tree
+        # aggregation.
+        while numPartitions > scale + numPartitions / scale:
+            numPartitions /= scale
+            curNumPartitions = numPartitions
+
+            def mapPartition(i, iterator):
+                for obj in iterator:
+                    yield (i % curNumPartitions, obj)
+
+            partiallyAggregated = partiallyAggregated \
+                .mapPartitionsWithIndex(mapPartition) \
+                .reduceByKey(combOp, curNumPartitions) \
+                .values()
+
+        return partiallyAggregated.reduce(combOp)
+
     def max(self, key=None):
         """
         Find the maximum item in this RDD.
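To close the loop, a driver-side sketch of the new API doing something more realistic than summing: a `(sum, count)` accumulator for a mean. This assumes a live `SparkContext` named `sc` on Spark 1.3+; any associative, commutative `combOp` with a true neutral `zeroValue` works the same way at every level of the tree:

```scala
val data = sc.parallelize(1 to 1000000, numSlices = 200)

// zeroValue (0L, 0L) is a neutral element for both operators.
val (sum, count) = data.treeAggregate((0L, 0L))(
  seqOp = { case ((s, c), x) => (s + x, c + 1) },                // fold one element in
  combOp = { case ((s1, c1), (s2, c2)) => (s1 + s2, c1 + c2) },  // merge partials
  depth = 2)

val mean = sum.toDouble / count  // 500000.5
```

With 200 partitions and depth 2, `scale` is ceil(200^(1/2)) = 15, so the partials are combined in one intermediate shuffle of 13 partitions before the final reduce reaches the driver.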