diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 4c4151933086f145df6fef8328608c641a741cd4..112beb2320a06fb9485b6e2df0677eeb83d29c9b 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -62,7 +62,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     }
     val aggregator =
       new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
-    if (Option(partitioner) == self.partitioner) {
+    if (self.partitioner == Some(partitioner)) {
       self.mapPartitions(aggregator.combineValuesByKey(_), true)
     } else if (mapSideCombine) {
       val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index d6efa3db43da3a18e7f91e807b33eabaebe4de24..50f2b294bfb059716da99ec34b1a018603788eaa 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -1,6 +1,7 @@
 package spark
 
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
 
 import org.scalatest.FunSuite
 import org.scalatest.matchers.ShouldMatchers
@@ -105,11 +106,20 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
       def numPartitions = 2
       def getPartition(key: Any) = key.asInstanceOf[Int]
     }
-    val pairs = rddToPairRDDFunctions(sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1)))).partitionBy(p)
-    val sums = pairs.reduceByKey(p, _+_)
-    println(sums.toDebugString)
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
+    val sums = pairs.reduceByKey(_+_)
     assert(sums.collect().toSet === Set((1, 4), (0, 1)))
     assert(sums.partitioner === Some(p))
+    // count the dependencies to make sure there is only 1 ShuffledRDD
+    val deps = new HashSet[RDD[_]]()
+    def visit(r: RDD[_]) {
+      for (dep <- r.dependencies) {
+        deps += dep.rdd
+        visit(dep.rdd)
+      }
+    }
+    visit(sums)
+    assert(deps.size === 2) // ShuffledRDD, ParallelCollection
   }
 
   test("join") {