Skip to content
Snippets Groups Projects
Commit c34b8ad2 authored by Stephen Haberman's avatar Stephen Haberman
Browse files

Avoid a shuffle if combineByKey is passed the same partitioner.

parent beb7ab87
No related branches found
No related tags found
No related merge requests found
......@@ -62,7 +62,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
val aggregator =
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (mapSideCombine) {
if (Option(partitioner) == self.partitioner) {
self.mapPartitions(aggregator.combineValuesByKey(_), true)
} else if (mapSideCombine) {
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
......
......@@ -98,6 +98,19 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
val sums = pairs.reduceByKey(_+_, 10).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
}
test("reduceByKey with partitioner") {
sc = new SparkContext("local", "test")
val p = new Partitioner() {
def numPartitions = 2
def getPartition(key: Any) = key.asInstanceOf[Int]
}
val pairs = rddToPairRDDFunctions(sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1)))).partitionBy(p)
val sums = pairs.reduceByKey(p, _+_)
println(sums.toDebugString)
assert(sums.collect().toSet === Set((1, 4), (0, 1)))
assert(sums.partitioner === Some(p))
}
test("join") {
sc = new SparkContext("local", "test")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment