diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index a2eadbcbd660a81f41ba150e88af05e588a43b52..4e1250a14d7ca4f1552ada2a81051b6cafd630c0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -378,59 +378,63 @@ class DAGScheduler(
    * the provided firstJobId.
    */
   private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
-    val parents = new HashSet[Stage]
-    val visited = new HashSet[RDD[_]]
-    // We are manually maintaining a stack here to prevent StackOverflowError
-    // caused by recursively visiting
-    val waitingForVisit = new Stack[RDD[_]]
-    def visit(r: RDD[_]) {
-      if (!visited(r)) {
-        visited += r
-        // Kind of ugly: need to register RDDs with the cache here since
-        // we can't do it in its constructor because # of partitions is unknown
-        for (dep <- r.dependencies) {
-          dep match {
-            case shufDep: ShuffleDependency[_, _, _] =>
-              parents += getShuffleMapStage(shufDep, firstJobId)
-            case _ =>
-              waitingForVisit.push(dep.rdd)
-          }
-        }
-      }
-    }
-    waitingForVisit.push(rdd)
-    while (waitingForVisit.nonEmpty) {
-      visit(waitingForVisit.pop())
-    }
-    parents.toList
+    getShuffleDependencies(rdd).map { shuffleDep =>
+      getShuffleMapStage(shuffleDep, firstJobId)
+    }.toList
   }
 
   /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */
   private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
-    val parents = new Stack[ShuffleDependency[_, _, _]]
+    val ancestors = new Stack[ShuffleDependency[_, _, _]]
     val visited = new HashSet[RDD[_]]
     // We are manually maintaining a stack here to prevent StackOverflowError
     // caused by recursively visiting
     val waitingForVisit = new Stack[RDD[_]]
-    def visit(r: RDD[_]) {
-      if (!visited(r)) {
-        visited += r
-        for (dep <- r.dependencies) {
-          dep match {
-            case shufDep: ShuffleDependency[_, _, _] =>
-              if (!shuffleToMapStage.contains(shufDep.shuffleId)) {
-                parents.push(shufDep)
-              }
-            case _ =>
-          }
-          waitingForVisit.push(dep.rdd)
+    waitingForVisit.push(rdd)
+    while (waitingForVisit.nonEmpty) {
+      val toVisit = waitingForVisit.pop()
+      if (!visited(toVisit)) {
+        visited += toVisit
+        getShuffleDependencies(toVisit).foreach { shuffleDep =>
+          if (!shuffleToMapStage.contains(shuffleDep.shuffleId)) {
+            ancestors.push(shuffleDep)
+            waitingForVisit.push(shuffleDep.rdd)
+          } // Otherwise, the dependency and its ancestors have already been registered.
         }
       }
     }
+    ancestors
+  }
 
+  /**
+   * Returns shuffle dependencies that are immediate parents of the given RDD.
+   *
+   * This function will not return more distant ancestors.  For example, if C has a shuffle
+   * dependency on B which has a shuffle dependency on A:
+   *
+   * A <-- B <-- C
+   *
+   * calling this function with rdd C will only return the B <-- C dependency.
+   *
+   * This function is scheduler-visible for the purpose of unit testing.
+   */
+  private[scheduler] def getShuffleDependencies(
+      rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = {
+    val parents = new HashSet[ShuffleDependency[_, _, _]]
+    val visited = new HashSet[RDD[_]]
+    val waitingForVisit = new Stack[RDD[_]]
     waitingForVisit.push(rdd)
     while (waitingForVisit.nonEmpty) {
-      visit(waitingForVisit.pop())
+      val toVisit = waitingForVisit.pop()
+      if (!visited(toVisit)) {
+        visited += toVisit
+        toVisit.dependencies.foreach {
+          case shuffleDep: ShuffleDependency[_, _, _] =>
+            parents += shuffleDep
+          case dependency =>
+            waitingForVisit.push(dependency.rdd)
+        }
+      }
     }
     parents
   }
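
Reviewer note (not part of the patch): the new getShuffleDependencies keeps the original
iterative depth-first traversal -- an explicit Stack plus a visited set, so deep chains of
narrow dependencies cannot overflow the call stack -- but it stops at the first shuffle edge
it meets instead of recursing past it. Below is a minimal, REPL-pasteable Scala sketch of that
pattern over a toy DAG; Node, Edge, Narrow, and Shuffle are hypothetical stand-ins for RDDs
and their dependencies, not Spark types.

```scala
import scala.collection.mutable

// Toy stand-ins for RDDs and their dependencies (hypothetical, not Spark types).
// An edge points at the *parent* node, mirroring how an RDD's dependencies
// reference its parent RDDs.
sealed trait Edge { def parent: Node }
case class Narrow(parent: Node) extends Edge   // traversal continues through these
case class Shuffle(parent: Node) extends Edge  // traversal records these and stops

case class Node(name: String, edges: Seq[Edge])

// Iterative DFS with an explicit stack, like the patched code, so that very
// deep narrow-dependency chains cannot cause a StackOverflowError.
def directShuffleParents(root: Node): Set[Node] = {
  val parents = mutable.HashSet[Node]()
  val visited = mutable.HashSet[Node]()
  val waitingForVisit = mutable.Stack[Node](root)
  while (waitingForVisit.nonEmpty) {
    val toVisit = waitingForVisit.pop()
    if (!visited(toVisit)) {
      visited += toVisit
      toVisit.edges.foreach {
        case Shuffle(parent) => parents += parent         // record; do not cross
        case Narrow(parent)  => waitingForVisit.push(parent)
      }
    }
  }
  parents.toSet
}

// The same A/B/C/D/E shape as the new unit test below:
// A <--s-- E, and B <--s-- C <--s-- D <--n-- E
val a = Node("A", Nil)
val b = Node("B", Nil)
val c = Node("C", Seq(Shuffle(b)))
val d = Node("D", Seq(Shuffle(c)))
val e = Node("E", Seq(Shuffle(a), Narrow(d)))
assert(directShuffleParents(e).map(_.name) == Set("A", "C"))  // direct parents only
```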
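The rewritten getAncestorShuffleDependencies differs from that direct-parent lookup in one way:
after recording an unregistered shuffle dependency it keeps walking *through* the shuffle
boundary (by pushing the dependency's parent RDD back onto the stack), and it prunes any shuffle
that is already registered, since that dependency's own ancestors must have been registered
along with it. A sketch of that variant, reusing the toy types above, with a hypothetical
"registered" set standing in for the shuffleToMapStage lookup:

```scala
// Collects every unregistered shuffle ancestor of root, not just direct parents.
// "registered" stands in for the shuffleToMapStage check in the real code.
def unregisteredShuffleAncestors(root: Node, registered: Set[String]): List[Node] = {
  val ancestors = mutable.Stack[Node]()
  val visited = mutable.HashSet[Node]()
  val waitingForVisit = mutable.Stack[Node](root)
  while (waitingForVisit.nonEmpty) {
    val toVisit = waitingForVisit.pop()
    if (!visited(toVisit)) {
      visited += toVisit
      directShuffleParents(toVisit).foreach { parent =>
        if (!registered(parent.name)) {
          ancestors.push(parent)
          waitingForVisit.push(parent) // keep going past this shuffle boundary
        } // Otherwise, this ancestor and everything behind it is already known.
      }
    }
  }
  ancestors.toList
}

// With nothing registered, E's unregistered shuffle ancestors are A, C, and B.
assert(unregisteredShuffleAncestors(e, Set.empty).map(_.name).toSet == Set("A", "B", "C"))
```
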
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 3c30ec8ee8e32ac13fb7ee8e870095c6e97ad8e3..ab8e95314fdf3dd2b0a3b477ca7c9f24798c8b60 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -2023,6 +2023,37 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     assertDataStructuresEmpty()
   }
 
+  /**
+   * Checks the DAGScheduler's internal logic for traversing an RDD DAG by making sure that
+   * getShuffleDependencies correctly returns the direct shuffle dependencies of a particular
+   * RDD. The test creates the following RDD graph (where n denotes a narrow dependency and s
+   * denotes a shuffle dependency):
+   *
+   * A <------------s---------,
+   *                           \
+   * B <--s-- C <--s-- D <--n---`-- E
+   *
+   * Here, the direct shuffle dependency of C is just the shuffle dependency on B. The direct
+   * shuffle dependencies of E are the shuffle dependency on A and the shuffle dependency on C.
+   */
+  test("getShuffleDependencies correctly returns only direct shuffle parents") {
+    val rddA = new MyRDD(sc, 2, Nil)
+    val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(1))
+    val rddB = new MyRDD(sc, 2, Nil)
+    val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(1))
+    val rddC = new MyRDD(sc, 1, List(shuffleDepB))
+    val shuffleDepC = new ShuffleDependency(rddC, new HashPartitioner(1))
+    val rddD = new MyRDD(sc, 1, List(shuffleDepC))
+    val narrowDepD = new OneToOneDependency(rddD)
+    val rddE = new MyRDD(sc, 1, List(shuffleDepA, narrowDepD), tracker = mapOutputTracker)
+
+    assert(scheduler.getShuffleDependencies(rddA) === Set())
+    assert(scheduler.getShuffleDependencies(rddB) === Set())
+    assert(scheduler.getShuffleDependencies(rddC) === Set(shuffleDepB))
+    assert(scheduler.getShuffleDependencies(rddD) === Set(shuffleDepC))
+    assert(scheduler.getShuffleDependencies(rddE) === Set(shuffleDepA, shuffleDepC))
+  }
+
   /**
    * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
    * Note that this checks only the host and not the executor ID.