diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 99d13b31ef62cadf756fc6318dfd309ef3e43c47..f622c413f7c5f8a51707f4edecb8eb4fc905e1d5 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -2,6 +2,7 @@ package spark
 
 import org.scalatest.FunSuite
 import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.ShouldMatchers
 import org.scalatest.prop.Checkers
 import org.scalacheck.Arbitrary._
 import org.scalacheck.Gen
@@ -13,7 +14,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import SparkContext._
 
-class ShuffleSuite extends FunSuite with BeforeAndAfter {
+class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
 
   var sc: SparkContext = _
 
@@ -196,4 +197,52 @@ class ShuffleSuite extends FunSuite with BeforeAndAfter {
     // Test that a shuffle on the file works, because this used to be a bug
     assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
   }
+
+  test("map-side combine") {
+    sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1), (1, 1)), 2)
+    // Per-key sums of `pairs`; every non-throwing variant below must produce exactly this.
+    val expectedSums = Set((1, 8), (2, 1))
+
+    // Test with map-side combine on.
+    val sums = pairs.reduceByKey(_ + _).collect()
+    assert(sums.toSet === expectedSums)
+
+    // Turn off map-side combine and test the results.
+    val aggregator = new Aggregator[Int, Int, Int](
+      (v: Int) => v,
+      _ + _,
+      _ + _,
+      false)
+    val shuffledRdd = new ShuffledRDD(
+      pairs, aggregator, new HashPartitioner(2))
+    assert(shuffledRdd.collect().toSet === expectedSums)
+
+    // Turn map-side combine off and pass a mergeCombine function that always
+    // throws. Should not see an exception because mergeCombine should not have
+    // been called.
+    val aggregatorWithException = new Aggregator[Int, Int, Int](
+      (v: Int) => v, _ + _, ShuffleSuite.mergeCombineException, false)
+    val shuffledRdd1 = new ShuffledRDD(
+      pairs, aggregatorWithException, new HashPartitioner(2))
+    assert(shuffledRdd1.collect().toSet === expectedSums)
+
+    // Now run the same mergeCombine function with map-side combine on. We
+    // expect to see an exception thrown.
+    val aggregatorWithException1 = new Aggregator[Int, Int, Int](
+      (v: Int) => v, _ + _, ShuffleSuite.mergeCombineException)
+    val shuffledRdd2 = new ShuffledRDD(
+      pairs, aggregatorWithException1, new HashPartitioner(2))
+    evaluating { shuffledRdd2.collect() } should produce [SparkException]
+  }
+}
+
+object ShuffleSuite {
+  /**
+   * A mergeCombine function that always throws a SparkException, so a test can
+   * tell whether the shuffle ever invoked it.
+   */
+  def mergeCombineException(x: Int, y: Int): Int = {
+    throw new SparkException("Exception for map-side combine.")
+  }
 }