From 5054abd41b4bac4b7c8159dc23c7ee15aeb7ef2a Mon Sep 17 00:00:00 2001
From: Reynold Xin <reynoldx@gmail.com>
Date: Mon, 19 Aug 2013 12:58:02 -0700
Subject: [PATCH] Code review feedback. (added tests for cogroup and subtract;
 added more documentation on MutablePair)

---
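Reviewer notes (below the fold; `git am` ignores everything between this
line and the diff):

The PairRDDFunctions hunks replace `for ((k, v) <- iter)` with
`iter.foreach { case (k, v) => ... }`. In Scala 2.x a pattern-matching
generator desugars with an extra `withFilter` step that re-checks the
pattern on every element, so the direct foreach form drops one wrapper per
traversal. A minimal sketch of the two forms (the exact desugaring depends
on the compiler version):

    val iter = Iterator((1, "a"), (2, "b"))
    // for-comprehension form: the compiler inserts a pattern-check wrapper
    for ((k, v) <- iter) println(k + " -> " + v)
    // direct form used in this patch: a single foreach call, no wrapper
    Iterator((1, "a"), (2, "b")).foreach { case (k, v) => println(k + " -> " + v) }
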
 .../main/scala/spark/PairRDDFunctions.scala   |  4 +-
 .../main/scala/spark/util/MutablePair.scala   | 16 +++----
 core/src/test/scala/spark/ShuffleSuite.scala  | 42 ++++++++++++++++++-
 3 files changed, 51 insertions(+), 11 deletions(-)
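
The new MutablePair scaladoc says the class exists to minimize object
allocation. A minimal usage sketch, assuming only what this patch shows of
the class (`_1` and `_2` are public vars):

    // Reuse one MutablePair across an iterator instead of allocating a
    // fresh Tuple2 per record.
    val pair = new MutablePair[Int, Int](0, 0)
    val out = Iterator((1, 2), (3, 4)).map { case (k, v) =>
      pair._1 = k
      pair._2 = v
      pair  // same object every time; consumers must not retain it
    }

Because one instance is handed out repeatedly, this is only safe when each
record is consumed (e.g. serialized by a shuffle writer) before the next
one is produced.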

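On the new tests: CoGroupedRDD yields, for each key, one Seq of values per
input RDD, which is why the cogroup assertions index twice (results(k)(0)
for values from pairs1, results(k)(1) for values from pairs2). A
plain-Scala analogue of the expected shape (hypothetical data, no Spark
required):

    val grouped: Map[Int, Seq[Seq[Any]]] = Map(
      1 -> Seq(Seq(1, 2, 3), Seq("11", "12")),
      2 -> Seq(Seq(1), Seq("22")),
      3 -> Seq(Seq(), Seq("3")))
    assert(grouped(1)(0).length == 3)   // three values for key 1 from pairs1
    assert(grouped(3)(0).isEmpty)       // key 3 never appears in pairs1
    assert(grouped(3)(1).contains("3")) // but it does appear in pairs2
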
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index f8900d3921..e7d4a7f562 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -165,7 +165,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
 
     def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
       val map = new JHashMap[K, V]
-      for ((k, v) <- iter) {
+      iter.foreach { case (k, v) =>
         val old = map.get(k)
         map.put(k, if (old == null) v else func(old, v))
       }
@@ -173,7 +173,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
     }
 
     def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = {
-      for ((k, v) <- m2) {
+      m2.foreach { case (k, v) =>
         val old = m1.get(k)
         m1.put(k, if (old == null) v else func(old, v))
       }
diff --git a/core/src/main/scala/spark/util/MutablePair.scala b/core/src/main/scala/spark/util/MutablePair.scala
index 117218bf47..3063806e83 100644
--- a/core/src/main/scala/spark/util/MutablePair.scala
+++ b/core/src/main/scala/spark/util/MutablePair.scala
@@ -18,17 +18,19 @@
 package spark.util
 
 
-/** A tuple of 2 elements.
-  *  @param  _1   Element 1 of this MutablePair
-  *  @param  _2   Element 2 of this MutablePair
-  */
+/**
+ * A tuple of 2 elements. This can be used as an alternative to Scala's Tuple2 when we want to
+ * minimize object allocation.
+ *
+ * @param  _1   Element 1 of this MutablePair
+ * @param  _2   Element 2 of this MutablePair
+ */
 case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1,
                       @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2]
-  (var _1: T1,var _2: T2)
+  (var _1: T1, var _2: T2)
   extends Product2[T1, T2]
 {
-
   override def toString = "(" + _1 + "," + _2 + ")"
 
-  def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[T1, T2]]
+  override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[T1, T2]]
 }
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index f1361546a3..8745689c70 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -22,8 +22,7 @@ import org.scalatest.matchers.ShouldMatchers
 
 import spark.SparkContext._
 import spark.ShuffleSuite.NonJavaSerializableClass
-import spark.rdd.OrderedRDDFunctions
-import spark.rdd.ShuffledRDD
+import spark.rdd.{SubtractedRDD, CoGroupedRDD, OrderedRDDFunctions, ShuffledRDD}
 import spark.util.MutablePair
 
 
@@ -159,6 +158,45 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     results(2) should be (p(3, 33))
     results(3) should be (p(100, 100))
   }
+
+  test("cogroup using mutable pairs") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    sc = new SparkContext("local-cluster[2,1,512]", "test")
+    def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
+    val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
+    val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3"))
+    val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2)
+    val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2)
+    val results = new CoGroupedRDD[Int](Seq(pairs1, pairs2), new HashPartitioner(2)).collectAsMap()
+
+    assert(results(1)(0).length === 3)
+    assert(results(1)(0).contains(1))
+    assert(results(1)(0).contains(2))
+    assert(results(1)(0).contains(3))
+    assert(results(1)(1).length === 2)
+    assert(results(1)(1).contains("11"))
+    assert(results(1)(1).contains("12"))
+    assert(results(2)(0).length === 1)
+    assert(results(2)(0).contains(1))
+    assert(results(2)(1).length === 1)
+    assert(results(2)(1).contains("22"))
+    assert(results(3)(0).length === 0)
+    assert(results(3)(1).contains("3"))
+  }
+
+  test("subtract mutable pairs") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    sc = new SparkContext("local-cluster[2,1,512]", "test")
+    def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
+    val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33))
+    val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"))
+    val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2)
+    val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2)
+    val results = new SubtractedRDD(pairs1, pairs2, new HashPartitioner(2)).collect()
+    results should have length (1)
+    // the subtracted RDD returns results as Tuple2
+    results(0) should be ((3, 33))
+  }
 }
 
 object ShuffleSuite {
-- 
GitLab