From 5945bcdcc56a71324357b02c21bef80dd7efd13a Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@cs.berkeley.edu>
Date: Wed, 29 Aug 2012 23:32:08 -0700
Subject: [PATCH] Added a new flag to Aggregator to control whether map-side
 combiners are applied.

---
 core/src/main/scala/spark/Aggregator.scala          | 12 +++++++++++-
 core/src/main/scala/spark/ShuffledRDD.scala         | 13 ++++++-------
 .../main/scala/spark/scheduler/ShuffleMapTask.scala |  2 +-
 3 files changed, 18 insertions(+), 9 deletions(-)
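
Note for reviewers: the sketch below is illustrative only -- the word-count
style functions and the val names are made up for this note and are not part
of the patch. It shows how a caller might construct an Aggregator under the
new flag:

    // Default behavior: combiners run on the map partitions, and
    // mergeCombiners folds their outputs together after the shuffle.
    val sumAgg = new Aggregator[String, Int, Int](
      createCombiner = (v: Int) => v,
      mergeValue = (c: Int, v: Int) => c + v,
      mergeCombiners = (c1: Int, c2: Int) => c1 + c2,
      mapSideCombine = true)

    // With mapSideCombine = false, raw (K, V) pairs are shuffled and the
    // reduce side folds them with mergeValue; mergeCombiners is never called.
    val noMapSideAgg = new Aggregator[String, Int, Int](
      createCombiner = (v: Int) => v,
      mergeValue = (c: Int, v: Int) => c + v,
      mergeCombiners = null, // unused when mapSideCombine = false
      mapSideCombine = false)
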

diff --git a/core/src/main/scala/spark/Aggregator.scala b/core/src/main/scala/spark/Aggregator.scala
index 6f99270b1e..6516bea157 100644
--- a/core/src/main/scala/spark/Aggregator.scala
+++ b/core/src/main/scala/spark/Aggregator.scala
@@ -1,7 +1,17 @@
 package spark
 
+/** A set of functions used to aggregate data.
+  *
+  * @param createCombiner function to create the initial value of the aggregation.
+  * @param mergeValue function to merge a new value into the aggregation result.
+  * @param mergeCombiners function to merge outputs from multiple mergeValue calls.
+  * @param mapSideCombine whether to apply combiners on map partitions, also
+  *                       known as map-side aggregation. When set to false, the
+  *                       mergeCombiners function is not used.
+  */
 class Aggregator[K, V, C] (
     val createCombiner: V => C,
     val mergeValue: (C, V) => C,
-    val mergeCombiners: (C, C) => C)
+    val mergeCombiners: (C, C) => C,
+    val mapSideCombine: Boolean = true)
   extends Serializable
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index 8293048caa..3616d8e47e 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -29,10 +29,9 @@ class ShuffledRDD[K, V, C](
     val combiners = new JHashMap[K, C]
     val fetcher = SparkEnv.get.shuffleFetcher
 
-    if (aggregator.mergeCombiners != null) {
-      // If mergeCombiners is specified, combiners are applied on the map
-      // partitions. In this case, post-shuffle we get a list of outputs from
-      // the combiners and merge them using mergeCombiners.
+    if (aggregator.mapSideCombine) {
+      // Apply combiners on map partitions. In this case, post-shuffle we get a
+      // list of outputs from the combiners and merge them using mergeCombiners.
       def mergePairWithMapSideCombiners(k: K, c: C) {
         val oldC = combiners.get(k)
         if (oldC == null) {
@@ -43,9 +42,9 @@ class ShuffledRDD[K, V, C](
       }
       fetcher.fetch[K, C](dep.shuffleId, split.index, mergePairWithMapSideCombiners)
     } else {
-      // If mergeCombiners is not specified, no combiner is applied on the map
-      // partitions (i.e. map side aggregation is turned off). Post-shuffle we
-      // get a list of values and we use mergeValue to merge them.
+      // Do not apply combiners on map partitions (i.e. map side aggregation is
+      // turned off). Post-shuffle we get a list of values and we use mergeValue
+      // to merge them.
       def mergePairWithoutMapSideCombiners(k: K, v: V) {
         val oldC = combiners.get(k)
         if (oldC == null) {
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 940932cc51..a281ae94c5 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -108,7 +108,7 @@ class ShuffleMapTask(
     val partitioner = dep.partitioner
 
     val bucketIterators =
-      if (aggregator.mergeCombiners != null) {
+      if (aggregator.mapSideCombine) {
         // Apply combiners (map-side aggregation) to the map output.
         val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
         for (elem <- rdd.iterator(split)) {
-- 
GitLab