diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 7021a339e879b10698b72285c6745373e6cbdab8..658e8c8b8931885c7e242af636c428610519ee3a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -29,15 +29,16 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer}
 import org.apache.spark.util.Utils
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.ShuffleHandle
-
-private[spark] sealed trait CoGroupSplitDep extends Serializable
 
+/** The references to rdd and splitIndex are transient because redundant information is stored
+  * in the CoGroupedRDD object.  Because CoGroupedRDD is serialized separately from
+  * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the
+  * task closure. */
 private[spark] case class NarrowCoGroupSplitDep(
-    rdd: RDD[_],
-    splitIndex: Int,
+    @transient rdd: RDD[_],
+    @transient splitIndex: Int,
     var split: Partition
-  ) extends CoGroupSplitDep {
+  ) extends Serializable {
 
   @throws(classOf[IOException])
   private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
@@ -47,9 +48,16 @@ private[spark] case class NarrowCoGroupSplitDep(
   }
 }
 
-private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep
-
-private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
+/**
+ * Stores information about the narrow dependencies used by a CoGroupedRdd.
+ *
+ * @param narrowDeps maps to the dependencies variable in the parent RDD: for each one to one
+ *                   dependency in dependencies, narrowDeps has a NarrowCoGroupSplitDep (describing
+ *                   the partition for that dependency) at the corresponding index. The size of
+ *                   narrowDeps should always be equal to the number of parents.
+ */
+private[spark] class CoGroupPartition(
+    idx: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]])
   extends Partition with Serializable {
   override val index: Int = idx
   override def hashCode(): Int = idx
@@ -105,9 +113,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
         // Assume each RDD contributed a single dependency, and get it
         dependencies(j) match {
           case s: ShuffleDependency[_, _, _] =>
-            new ShuffleCoGroupSplitDep(s.shuffleHandle)
+            None
           case _ =>
-            new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
+            Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)))
         }
       }.toArray)
     }
@@ -120,20 +128,21 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
     val sparkConf = SparkEnv.get.conf
     val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true)
     val split = s.asInstanceOf[CoGroupPartition]
-    val numRdds = split.deps.length
+    val numRdds = dependencies.length
 
     // A list of (rdd iterator, dependency number) pairs
     val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
-    for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
-      case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
+    for ((dep, depNum) <- dependencies.zipWithIndex) dep match {
+      case oneToOneDependency: OneToOneDependency[Product2[K, Any]] =>
+        val dependencyPartition = split.narrowDeps(depNum).get.split
         // Read them from the parent
-        val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
+        val it = oneToOneDependency.rdd.iterator(dependencyPartition, context)
         rddIterators += ((it, depNum))
 
-      case ShuffleCoGroupSplitDep(handle) =>
+      case shuffleDependency: ShuffleDependency[_, _, _] =>
         // Read map outputs of shuffle
         val it = SparkEnv.get.shuffleManager
-          .getReader(handle, split.index, split.index + 1, context)
+          .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)
           .read()
         rddIterators += ((it, depNum))
     }
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index e9d745588ee9a2b767305c581a953793c4330b1f..633aeba3bbae6424cbedc5bb632d2442d10480dd 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -81,9 +81,9 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
       array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
         dependencies(j) match {
           case s: ShuffleDependency[_, _, _] =>
-            new ShuffleCoGroupSplitDep(s.shuffleHandle)
+            None
           case _ =>
-            new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
+            Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)))
         }
       }.toArray)
     }
@@ -105,20 +105,26 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
         seq
       }
     }
-    def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit): Unit = dep match {
-      case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
-        rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
+    def integrate(depNum: Int, op: Product2[K, V] => Unit) = {
+      dependencies(depNum) match {
+        case oneToOneDependency: OneToOneDependency[_] =>
+          val dependencyPartition = partition.narrowDeps(depNum).get.split
+          oneToOneDependency.rdd.iterator(dependencyPartition, context)
+            .asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
 
-      case ShuffleCoGroupSplitDep(handle) =>
-        val iter = SparkEnv.get.shuffleManager
-          .getReader(handle, partition.index, partition.index + 1, context)
-          .read()
-        iter.foreach(op)
+        case shuffleDependency: ShuffleDependency[_, _, _] =>
+          val iter = SparkEnv.get.shuffleManager
+            .getReader(
+              shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context)
+            .read()
+          iter.foreach(op)
+      }
     }
+
     // the first dep is rdd1; add all values to the map
-    integrate(partition.deps(0), t => getSeq(t._1) += t._2)
+    integrate(0, t => getSeq(t._1) += t._2)
     // the second dep is rdd2; remove all of its keys
-    integrate(partition.deps(1), t => map.remove(t._1))
+    integrate(1, t => map.remove(t._1))
     map.iterator.map { t =>  t._2.iterator.map { (t._1, _) } }.flatten
   }