diff --git a/src/test/spark/ShuffleSuite.scala b/src/test/spark/ShuffleSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..2898bd09c807be3e0c63e336f96120a31abf1dfa
--- /dev/null
+++ b/src/test/spark/ShuffleSuite.scala
@@ -0,0 +1,119 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import SparkContext._
+
+class ShuffleSuite extends FunSuite {
+  test("groupByKey") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+    val groups = pairs.groupByKey().collect()
+    assert(groups.size === 2)
+    val valuesFor1 = groups.find(_._1 == 1).get._2
+    assert(valuesFor1.toList.sorted === List(1, 2, 3))
+    val valuesFor2 = groups.find(_._1 == 2).get._2
+    assert(valuesFor2.toList.sorted === List(1))
+  }
+
+  test("groupByKey with duplicates") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+    val groups = pairs.groupByKey().collect()
+    assert(groups.size === 2)
+    val valuesFor1 = groups.find(_._1 == 1).get._2
+    assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
+    val valuesFor2 = groups.find(_._1 == 2).get._2
+    assert(valuesFor2.toList.sorted === List(1))
+  }
+
+  test("groupByKey with many output partitions") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+    val groups = pairs.groupByKey(10).collect()
+    assert(groups.size === 2)
+    val valuesFor1 = groups.find(_._1 == 1).get._2
+    assert(valuesFor1.toList.sorted === List(1, 2, 3))
+    val valuesFor2 = groups.find(_._1 == 2).get._2
+    assert(valuesFor2.toList.sorted === List(1))
+  }
+
+  test("reduceByKey") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+    val sums = pairs.reduceByKey(_+_).collect()
+    assert(sums.toSet === Set((1, 7), (2, 1)))
+  }
+
+  test("reduceByKey with collectAsMap") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+    val sums = pairs.reduceByKey(_+_).collectAsMap()
+    assert(sums.size === 2)
+    assert(sums(1) === 7)
+    assert(sums(2) === 1)
+  }
+
+  test("reduceByKey with many output partitons") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+    val sums = pairs.reduceByKey(_+_, 10).collect()
+    assert(sums.toSet === Set((1, 7), (2, 1)))
+  }
+
+  test("join") {
+    val sc = new SparkContext("local", "test")
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+    val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+    val joined = rdd1.join(rdd2).collect()
+    assert(joined.size === 4)
+    assert(joined.toSet === Set(
+      (1, (1, 'x')),
+      (1, (2, 'x')),
+      (2, (1, 'y')),
+      (2, (1, 'z'))
+    ))
+  }
+
+  test("join all-to-all") {
+    val sc = new SparkContext("local", "test")
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
+    val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
+    val joined = rdd1.join(rdd2).collect()
+    assert(joined.size === 6)
+    assert(joined.toSet === Set(
+      (1, (1, 'x')),
+      (1, (1, 'y')),
+      (1, (2, 'x')),
+      (1, (2, 'y')),
+      (1, (3, 'x')),
+      (1, (3, 'y'))
+    ))
+  }
+
+  test("join with no matches") {
+    val sc = new SparkContext("local", "test")
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+    val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
+    val joined = rdd1.join(rdd2).collect()
+    assert(joined.size === 0)
+  }
+
+  test("join with many output partitions") {
+    val sc = new SparkContext("local", "test")
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+    val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+    val joined = rdd1.join(rdd2, 10).collect()
+    assert(joined.size === 4)
+    assert(joined.toSet === Set(
+      (1, (1, 'x')),
+      (1, (2, 'x')),
+      (2, (1, 'y')),
+      (2, (1, 'z'))
+    ))
+  }
+}