From c34b8ad2c59697b3e1f5034074e5de0d3b32b8f9 Mon Sep 17 00:00:00 2001
From: Stephen Haberman <stephen@exigencecorp.com>
Date: Sat, 16 Feb 2013 00:54:03 -0600
Subject: [PATCH] Avoid a shuffle if combineByKey is passed the same
 partitioner.

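If the RDD is already partitioned by the exact partitioner that
combineByKey receives, every value for a given key already lives in a
single partition, so the combiners can be built with one local
mapPartitions pass and the shuffle can be skipped entirely.

A minimal sketch of the case this optimizes (assumes an existing
SparkContext `sc`; HashPartitioner stands in for any Partitioner, and
reduceByKey goes through combineByKey with the given partitioner):

    // sc is assumed to be an already-constructed SparkContext.
    val p = new HashPartitioner(2)
    val pairs = sc.parallelize(Seq((1, 1), (1, 2), (2, 3))).partitionBy(p)
    // `pairs` already carries partitioner `p`, so this reduceByKey
    // combines values within each partition and performs no shuffle.
    val sums = pairs.reduceByKey(p, _ + _)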
---
 core/src/main/scala/spark/PairRDDFunctions.scala |  4 +++-
 core/src/test/scala/spark/ShuffleSuite.scala     | 13 +++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index cc3cca2571..4c41519330 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -62,7 +62,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     }
     val aggregator =
       new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
-    if (mapSideCombine) {
+    if (Option(partitioner) == self.partitioner) {
+      self.mapPartitions(aggregator.combineValuesByKey(_), true)
+    } else if (mapSideCombine) {
       val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
       val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
       partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 3493b9511f..d6efa3db43 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -98,6 +98,19 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     val sums = pairs.reduceByKey(_+_, 10).collect()
     assert(sums.toSet === Set((1, 7), (2, 1)))
   }
+
+  test("reduceByKey with partitioner") {
+    sc = new SparkContext("local", "test")
+    val p = new Partitioner() {
+      def numPartitions = 2
+      def getPartition(key: Any) = key.asInstanceOf[Int]
+    }
+    val pairs = rddToPairRDDFunctions(sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1)))).partitionBy(p)
+    val sums = pairs.reduceByKey(p, _+_)
+    println(sums.toDebugString)
+    assert(sums.collect().toSet === Set((1, 4), (0, 1)))
+    assert(sums.partitioner === Some(p))
+  }
 
   test("join") {
     sc = new SparkContext("local", "test")
-- 
GitLab