From 5054abd41b4bac4b7c8159dc23c7ee15aeb7ef2a Mon Sep 17 00:00:00 2001
From: Reynold Xin <reynoldx@gmail.com>
Date: Mon, 19 Aug 2013 12:58:02 -0700
Subject: [PATCH] Code review feedback. (added tests for cogroup and subtract;
 added more documentation on MutablePair)

---
 .../main/scala/spark/PairRDDFunctions.scala  |  4 +-
 .../main/scala/spark/util/MutablePair.scala  | 16 +++--
 core/src/test/scala/spark/ShuffleSuite.scala | 42 ++++++++++++++++++-
 3 files changed, 51 insertions(+), 11 deletions(-)

diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index f8900d3921..e7d4a7f562 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -165,7 +165,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
 
     def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
       val map = new JHashMap[K, V]
-      for ((k, v) <- iter) {
+      iter.foreach { case (k, v) =>
         val old = map.get(k)
         map.put(k, if (old == null) v else func(old, v))
       }
@@ -173,7 +173,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
     }
 
     def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = {
-      for ((k, v) <- m2) {
+      m2.foreach { case (k, v) =>
         val old = m1.get(k)
         m1.put(k, if (old == null) v else func(old, v))
       }
diff --git a/core/src/main/scala/spark/util/MutablePair.scala b/core/src/main/scala/spark/util/MutablePair.scala
index 117218bf47..3063806e83 100644
--- a/core/src/main/scala/spark/util/MutablePair.scala
+++ b/core/src/main/scala/spark/util/MutablePair.scala
@@ -18,17 +18,19 @@
 
 package spark.util
 
-/** A tuple of 2 elements.
- * @param _1 Element 1 of this MutablePair
- * @param _2 Element 2 of this MutablePair
- */
+/**
+ * A tuple of 2 elements. This can be used as an alternative to Scala's Tuple2 when we want to
+ * minimize object allocation.
+ *
+ * @param _1 Element 1 of this MutablePair
+ * @param _2 Element 2 of this MutablePair
+ */
 case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1,
                        @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2]
-  (var _1: T1,var _2: T2)
+  (var _1: T1, var _2: T2)
   extends Product2[T1, T2]
 {
-
   override def toString = "(" + _1 + "," + _2 + ")"
 
-  def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[T1, T2]]
+  override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[T1, T2]]
 }
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index f1361546a3..8745689c70 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -22,8 +22,7 @@ import org.scalatest.matchers.ShouldMatchers
 
 import spark.SparkContext._
 import spark.ShuffleSuite.NonJavaSerializableClass
-import spark.rdd.OrderedRDDFunctions
-import spark.rdd.ShuffledRDD
+import spark.rdd.{SubtractedRDD, CoGroupedRDD, OrderedRDDFunctions, ShuffledRDD}
 import spark.util.MutablePair
 
 
@@ -159,6 +158,45 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     results(2) should be (p(3, 33))
     results(3) should be (p(100, 100))
   }
+
+  test("cogroup using mutable pairs") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    sc = new SparkContext("local-cluster[2,1,512]", "test")
+    def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
+    val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
+    val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3"))
+    val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2)
+    val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2)
+    val results = new CoGroupedRDD[Int](Seq(pairs1, pairs2), new HashPartitioner(2)).collectAsMap()
+
+    assert(results(1)(0).length === 3)
+    assert(results(1)(0).contains(1))
+    assert(results(1)(0).contains(2))
+    assert(results(1)(0).contains(3))
+    assert(results(1)(1).length === 2)
+    assert(results(1)(1).contains("11"))
+    assert(results(1)(1).contains("12"))
+    assert(results(2)(0).length === 1)
+    assert(results(2)(0).contains(1))
+    assert(results(2)(1).length === 1)
+    assert(results(2)(1).contains("22"))
+    assert(results(3)(0).length === 0)
+    assert(results(3)(1).contains("3"))
+  }
+
+  test("subtract mutable pairs") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    sc = new SparkContext("local-cluster[2,1,512]", "test")
+    def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
+    val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33))
+    val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"))
+    val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2)
+    val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2)
+    val results = new SubtractedRDD(pairs1, pairs2, new HashPartitioner(2)).collect()
+    results should have length (1)
+    // the subtracted RDD returns results as Tuple2
+    results(0) should be ((3, 33))
+  }
 }
 
 object ShuffleSuite {
-- 
GitLab
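
A note for context, not part of the patch itself: the new MutablePair scaladoc
describes it as an alternative to Scala's Tuple2 for minimizing object
allocation, and the new tests feed MutablePairs directly into CoGroupedRDD and
SubtractedRDD. Below is a minimal sketch of that reuse pattern, assuming only
the spark.RDD and spark.util.MutablePair APIs shown above; the toLengthPairs
helper and its words parameter are hypothetical names used for illustration.

    import spark.RDD
    import spark.util.MutablePair

    // Emit one (word, length) pair per record while allocating a single pair
    // object per partition, instead of a new Tuple2 for every record.
    def toLengthPairs(words: RDD[String]): RDD[MutablePair[String, Int]] =
      words.mapPartitions { iter =>
        val pair = MutablePair[String, Int]("", 0)  // reused for every record
        iter.map { word =>
          pair._1 = word         // mutate the pair in place rather than
          pair._2 = word.length  // allocating a fresh tuple per element
          pair
        }
      }

This reuse is only safe when each emitted pair is consumed (for example,
serialized by the shuffle writer) before the next iteration mutates it, which
is the situation the two new tests exercise.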