From a8a2a08a1a7e652920702f25a89e43788d538d05 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@cs.berkeley.edu>
Date: Thu, 30 Aug 2012 12:34:28 -0700
Subject: [PATCH] Added a test for the map-side combine on/off switch.

---
 core/src/test/scala/spark/ShuffleSuite.scala | 45 +++++++++++++++++++-
 1 file changed, 44 insertions(+), 1 deletion(-)
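
Note (not part of the commit): for readers unfamiliar with map-side combine,
below is a minimal, Spark-free sketch of the idea using plain Scala
collections. All names in it (MapSideCombineSketch, combinePartition,
shuffleAndReduce) are invented for illustration and are not Spark APIs. With
the switch on, each map partition pre-aggregates its pairs locally so fewer
records cross the shuffle; with it off, raw pairs are shuffled and the reduce
side does all the merging. Either way the final totals are identical.

object MapSideCombineSketch {
  // One "map task" per partition of (key, value) pairs; same data as the test.
  val partitions: Seq[Seq[(Int, Int)]] =
    Seq(Seq((1, 1), (1, 2), (1, 3)), Seq((1, 1), (2, 1), (1, 1)))

  // Map-side combine ON: pre-aggregate locally, so at most one record
  // per key per partition crosses the "network".
  def combinePartition(part: Seq[(Int, Int)]): Seq[(Int, Int)] =
    part.groupBy(_._1).map { case (k, vs) => (k, vs.map(_._2).sum) }.toSeq

  // The reduce side merges whatever arrives, pre-combined or not.
  def shuffleAndReduce(parts: Seq[Seq[(Int, Int)]]): Map[Int, Int] =
    parts.flatten.groupBy(_._1).map { case (k, vs) => (k, vs.map(_._2).sum) }

  def main(args: Array[String]): Unit = {
    val withCombine    = shuffleAndReduce(partitions.map(combinePartition))
    val withoutCombine = shuffleAndReduce(partitions)
    assert(withCombine == withoutCombine)  // same result, less data shuffled
    println(withCombine)                   // Map(1 -> 8, 2 -> 1)
  }
}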

diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 99d13b31ef..f622c413f7 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -2,6 +2,7 @@ package spark
 
 import org.scalatest.FunSuite
 import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.ShouldMatchers
 import org.scalatest.prop.Checkers
 import org.scalacheck.Arbitrary._
 import org.scalacheck.Gen
@@ -13,7 +14,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import SparkContext._
 
-class ShuffleSuite extends FunSuite with BeforeAndAfter {
+class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
   
   var sc: SparkContext = _
   
@@ -196,4 +197,46 @@ class ShuffleSuite extends FunSuite with BeforeAndAfter {
     // Test that a shuffle on the file works, because this used to be a bug
     assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)    
   }
+
+  test("map-side combine") {
+    sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1), (1, 1)), 2)
+
+    // Test with map-side combine on.
+    val sums = pairs.reduceByKey(_ + _).collect()
+    assert(sums.toSet === Set((1, 8), (2, 1)))
+
+    // Turn off map-side combine and test the results.
+    val aggregator = new Aggregator[Int, Int, Int](
+      (v: Int) => v, // create a combiner from the first value
+      _ + _,         // merge a value into a combiner
+      _ + _,         // merge two combiners
+      false)         // map-side combine switched off
+    val shuffledRdd = new ShuffledRDD(
+      pairs, aggregator, new HashPartitioner(2))
+    assert(shuffledRdd.collect().toSet === Set((1, 8), (2, 1)))
+
+    // Turn map-side combine off and pass a combiner-merging function that
+    // throws. No exception should surface, because the function is never called.
+    val aggregatorWithException = new Aggregator[Int, Int, Int](
+      (v: Int) => v, _ + _, ShuffleSuite.mergeCombineException, false)
+    val shuffledRdd1 = new ShuffledRDD(
+      pairs, aggregatorWithException, new HashPartitioner(2))
+    assert(shuffledRdd1.collect().toSet === Set((1, 8), (2, 1)))
+
+    // Now run the same throwing merge function with map-side combine on. This
+    // time it should be invoked, so we expect a SparkException.
+    val aggregatorWithException1 = new Aggregator[Int, Int, Int](
+      (v: Int) => v, _ + _, ShuffleSuite.mergeCombineException)
+    val shuffledRdd2 = new ShuffledRDD(
+      pairs, aggregatorWithException1, new HashPartitioner(2))
+    evaluating { shuffledRdd2.collect() } should produce [SparkException]
+  }
+}
+
+object ShuffleSuite {
+  def mergeCombineException(x: Int, y: Int): Int = {
+    // Fail loudly if map-side combine ever invokes this merge function.
+    throw new SparkException("Exception for map-side combine.")
+  }
 }
-- 
GitLab
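
Editor's note: mergeCombineException is defined on the ShuffleSuite companion
object rather than inline in the test. A plausible reason (an assumption, not
stated in the patch): functions used in RDD operations are serialized and
shipped to tasks, and a method on a top-level object can be referenced without
capturing the enclosing FunSuite instance, which is not serializable. A
hypothetical sketch of the pitfall, with invented names:

class MySuite {
  val scale = 2
  // Referencing `scale` makes this closure capture the enclosing MySuite
  // instance, so shipping it to a task would require MySuite to serialize.
  def capturing: (Int, Int) => Int = (x, y) => (x + y) * scale
}
object MySuiteHelpers {
  // A method on a top-level object captures no instance state.
  def free(x: Int, y: Int): Int = x + y
}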