diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index fa9df3a97e72e66cb735afc864582f38300e39ec..0be4b4feb8c360e9812c0d7d7db118f713ab8dcf 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -85,17 +85,18 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     }
     val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
     if (self.partitioner == Some(partitioner)) {
-      self.mapPartitions(aggregator.combineValuesByKey(_), true)
+      self.mapPartitions(aggregator.combineValuesByKey, true)
     } else if (mapSideCombine) {
-      val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
-      val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass)
-      partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
+      val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey, true)
+      val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
+        .setSerializer(serializerClass)
+      partitioned.mapPartitions(aggregator.combineCombinersByKey, true)
     } else {
       // Don't apply map-side combiner.
       // A sanity check to make sure mergeCombiners is not defined.
       assert(mergeCombiners == null)
-      val values = new ShuffledRDD[K, V](self, partitioner, serializerClass)
-      values.mapPartitions(aggregator.combineValuesByKey(_), true)
+      val values = new ShuffledRDD[K, V](self, partitioner).setSerializer(serializerClass)
+      values.mapPartitions(aggregator.combineValuesByKey, true)
     }
   }
 
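With the serializer no longer a constructor argument, the combineByKey path above selects it through the chained setter. A minimal usage sketch, assuming a live SparkContext named sc and the KryoSerializer that ships with Spark; the variable names are illustrative only:

import spark.HashPartitioner
import spark.rdd.ShuffledRDD

// Illustrative sketch: a pair RDD shuffled with an explicitly chosen serializer.
val pairs = sc.parallelize(1 to 100).map(i => (i % 10, i))
val shuffled = new ShuffledRDD[Int, Int](pairs, new HashPartitioner(4))
  .setSerializer(classOf[spark.KryoSerializer].getName)  // replaces the old constructor argument
shuffled.count()
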
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 019b12d2d5e43796c52ff6bc1ff165642abbe9fd..c2d95dc06049f8b603abbf9373422dfc4739c882 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -60,12 +60,16 @@ class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
  * @param rdds parent RDDs.
  * @param part partitioner used to partition the shuffle output.
  */
-class CoGroupedRDD[K](
-  @transient var rdds: Seq[RDD[(K, _)]],
-  part: Partitioner,
-  val serializerClass: String = null)
+class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
   extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
 
+  private var serializerClass: String = null
+
+  def setSerializer(cls: String): CoGroupedRDD[K] = {
+    serializerClass = cls
+    this
+  }
+
   override def getDependencies: Seq[Dependency[_]] = {
     rdds.map { rdd: RDD[(K, _)] =>
       if (rdd.partitioner == Some(part)) {
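Since setSerializer returns the RDD itself, construction and configuration remain a single expression. A hedged sketch of a cogroup that picks Kryo for its shuffle; the RDDs users and orders are hypothetical pair RDDs keyed by String:

import spark.{HashPartitioner, RDD}
import spark.rdd.CoGroupedRDD

// Illustrative sketch: cogroup two String-keyed RDDs, shuffling with Kryo.
val grouped = new CoGroupedRDD[String](Seq[RDD[(String, _)]](users, orders), new HashPartitioner(8))
  .setSerializer(classOf[spark.KryoSerializer].getName)
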
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 0137f80953e02a14e537f34539b4f6710682c4f8..bcf7d0d89c4451a1cb15457316a1fdb91fa8a788 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -17,8 +17,8 @@
 
 package spark.rdd
 
-import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
-import spark.SparkContext._
+import spark._
+import scala.Some
 
 
 private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
@@ -30,15 +31,24 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
  * The resulting RDD from a shuffle (e.g. repartitioning of data).
  * @param prev the parent RDD.
  * @param part the partitioner used to partition the RDD
- * @param serializerClass class name of the serializer to use.
  * @tparam K the key class.
  * @tparam V the value class.
  */
 class ShuffledRDD[K, V](
-    @transient prev: RDD[(K, V)],
-    part: Partitioner,
-    serializerClass: String = null)
-  extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part, serializerClass))) {
+    @transient var prev: RDD[(K, V)],
+    part: Partitioner)
+  extends RDD[(K, V)](prev.context, Nil) {
+
+  private var serializerClass: String = null
+
+  def setSerializer(cls: String): ShuffledRDD[K, V] = {
+    serializerClass = cls
+    this
+  }
+
+  override def getDependencies: Seq[Dependency[_]] = {
+    List(new ShuffleDependency(prev, part, serializerClass))
+  }
 
   override val partitioner = Some(part)
 
@@ -51,4 +61,9 @@ class ShuffledRDD[K, V](
     SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics,
       SparkEnv.get.serializerManager.get(serializerClass))
   }
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    prev = null
+  }
 }
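Making prev a var and overriding clearDependencies follows the convention Spark's other shuffle RDDs use for checkpointing: once the shuffled output is materialized and checkpointed, the parent lineage can be dropped. Because the ShuffleDependency is now built lazily in getDependencies, it also honors a serializer set any time before the RDD is first evaluated. A rough sketch of this flow, reusing the pairs RDD from the earlier sketch; the checkpoint directory is an illustrative path:

// Illustrative sketch: checkpoint the shuffle output, after which the RDD no
// longer needs its parent lineage and clearDependencies() lets Spark drop it.
sc.setCheckpointDir("/tmp/spark-checkpoints")            // illustrative path
val shuffled = new ShuffledRDD[Int, Int](pairs, new HashPartitioner(4))
  .setSerializer(classOf[spark.KryoSerializer].getName)  // set before the first action
shuffled.checkpoint()
shuffled.count()                                         // runs the job and writes the checkpoint
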
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 0402b9f2501d62e2fa9e49aa096ee193926519be..46b8cafaac3396f2dcf39f87e921b3de19001ca2 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -49,10 +49,16 @@ import spark.OneToOneDependency
 private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
     @transient var rdd1: RDD[(K, V)],
     @transient var rdd2: RDD[(K, W)],
-    part: Partitioner,
-    val serializerClass: String = null)
+    part: Partitioner)
   extends RDD[(K, V)](rdd1.context, Nil) {
 
+  private var serializerClass: String = null
+
+  def setSerializer(cls: String): SubtractedRDD[K, V, W] = {
+    serializerClass = cls
+    this
+  }
+
   override def getDependencies: Seq[Dependency[_]] = {
     Seq(rdd1, rdd2).map { rdd =>
       if (rdd.partitioner == Some(part)) {
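SubtractedRDD gets the same treatment, so subtracting by key with a non-default serializer reads like the other shuffle RDDs. A hedged sketch; base and toRemove are hypothetical pair RDDs of type RDD[(String, Int)]:

import spark.HashPartitioner
import spark.rdd.SubtractedRDD

// Illustrative sketch: keep only the keys of base that do not appear in toRemove,
// shuffling with Kryo instead of the default serializer.
val remaining = new SubtractedRDD[String, Int, Int](base, toRemove, new HashPartitioner(4))
  .setSerializer(classOf[spark.KryoSerializer].getName)
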
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 752e4b85e67d2585f128a9e76b754aa49cafee1a..c686b8cc5af43a04dddad6d9256b9339b947d0e8 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -17,17 +17,8 @@
 
 package spark
 
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
-
 import org.scalatest.FunSuite
 import org.scalatest.matchers.ShouldMatchers
-import org.scalatest.prop.Checkers
-import org.scalacheck.Arbitrary._
-import org.scalacheck.Gen
-import org.scalacheck.Prop._
-
-import com.google.common.io.Files
 
 import spark.rdd.ShuffledRDD
 import spark.SparkContext._
@@ -59,8 +50,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     }
     // If the Kryo serializer is not used correctly, the shuffle would fail because the
     // default Java serializer cannot handle the non serializable class.
-    val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS),
-      classOf[spark.KryoSerializer].getName)
+    val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS))
+      .setSerializer(classOf[spark.KryoSerializer].getName)
     val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
 
     assert(c.count === 10)
@@ -81,7 +72,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     }
     // If the Kryo serializer is not used correctly, the shuffle would fail because the
     // default Java serializer cannot handle the non serializable class.
-    val c = new ShuffledRDD(b, new HashPartitioner(3), classOf[spark.KryoSerializer].getName)
+    val c = new ShuffledRDD(b, new HashPartitioner(3))
+      .setSerializer(classOf[spark.KryoSerializer].getName)
     assert(c.count === 10)
   }
 
@@ -96,7 +88,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
 
     // NOTE: The default Java serializer doesn't create zero-sized blocks.
     //       So, use Kryo
-    val c = new ShuffledRDD(b, new HashPartitioner(10), classOf[spark.KryoSerializer].getName)
+    val c = new ShuffledRDD(b, new HashPartitioner(10))
+      .setSerializer(classOf[spark.KryoSerializer].getName)
 
     val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
     assert(c.count === 4)