From 5d50d4f0f9db3e6cc7c51e35cdb2d12daa4fd108 Mon Sep 17 00:00:00 2001
From: Kay Ousterhout <kayousterhout@gmail.com>
Date: Tue, 14 Jun 2016 17:26:33 -0700
Subject: [PATCH] [SPARK-15927] Eliminate redundant DAGScheduler code.

To eliminate redundant code for traversing the RDD dependency graph,
this PR creates a new function, getShuffleDependencies, that returns
the shuffle dependencies that are immediate parents of a given RDD.
This new function is used by getParentStages and
getAncestorShuffleDependencies.
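
For illustration, a minimal sketch of the new function's contract,
using the MyRDD helper and the scheduler handle from
DAGSchedulerSuite (the same pieces the new test in this patch
relies on):

    // A <--s-- B <--s-- C: calling getShuffleDependencies on C
    // returns only the direct B <-- C shuffle dependency, not the
    // more distant A <-- B dependency.
    val rddA = new MyRDD(sc, 2, Nil)
    val depA = new ShuffleDependency(rddA, new HashPartitioner(1))
    val rddB = new MyRDD(sc, 2, List(depA))
    val depB = new ShuffleDependency(rddB, new HashPartitioner(1))
    val rddC = new MyRDD(sc, 2, List(depB))
    assert(scheduler.getShuffleDependencies(rddC) === Set(depB))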

Author: Kay Ousterhout <kayousterhout@gmail.com>

Closes #13646 from kayousterhout/SPARK-15927.
---
 .../apache/spark/scheduler/DAGScheduler.scala | 82 ++++++++++---------
 .../spark/scheduler/DAGSchedulerSuite.scala   | 31 +++++++
 2 files changed, 74 insertions(+), 39 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index a2eadbcbd6..4e1250a14d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -378,59 +378,63 @@ class DAGScheduler(
    * the provided firstJobId.
    */
   private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
-    val parents = new HashSet[Stage]
-    val visited = new HashSet[RDD[_]]
-    // We are manually maintaining a stack here to prevent StackOverflowError
-    // caused by recursively visiting
-    val waitingForVisit = new Stack[RDD[_]]
-    def visit(r: RDD[_]) {
-      if (!visited(r)) {
-        visited += r
-        // Kind of ugly: need to register RDDs with the cache here since
-        // we can't do it in its constructor because # of partitions is unknown
-        for (dep <- r.dependencies) {
-          dep match {
-            case shufDep: ShuffleDependency[_, _, _] =>
-              parents += getShuffleMapStage(shufDep, firstJobId)
-            case _ =>
-              waitingForVisit.push(dep.rdd)
-          }
-        }
-      }
-    }
-    waitingForVisit.push(rdd)
-    while (waitingForVisit.nonEmpty) {
-      visit(waitingForVisit.pop())
-    }
-    parents.toList
+    getShuffleDependencies(rdd).map { shuffleDep =>
+      getShuffleMapStage(shuffleDep, firstJobId)
+    }.toList
   }
 
   /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */
   private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
-    val parents = new Stack[ShuffleDependency[_, _, _]]
+    val ancestors = new Stack[ShuffleDependency[_, _, _]]
     val visited = new HashSet[RDD[_]]
     // We are manually maintaining a stack here to prevent StackOverflowError
     // caused by recursively visiting
     val waitingForVisit = new Stack[RDD[_]]
-    def visit(r: RDD[_]) {
-      if (!visited(r)) {
-        visited += r
-        for (dep <- r.dependencies) {
-          dep match {
-            case shufDep: ShuffleDependency[_, _, _] =>
-              if (!shuffleToMapStage.contains(shufDep.shuffleId)) {
-                parents.push(shufDep)
-              }
-            case _ =>
-          }
-          waitingForVisit.push(dep.rdd)
+    waitingForVisit.push(rdd)
+    while (waitingForVisit.nonEmpty) {
+      val toVisit = waitingForVisit.pop()
+      if (!visited(toVisit)) {
+        visited += toVisit
+        getShuffleDependencies(toVisit).foreach { shuffleDep =>
+          if (!shuffleToMapStage.contains(shuffleDep.shuffleId)) {
+            ancestors.push(shuffleDep)
+            waitingForVisit.push(shuffleDep.rdd)
+          } // Otherwise, the dependency and its ancestors have already been registered.
         }
       }
     }
+    ancestors
+  }
 
+  /**
+   * Returns shuffle dependencies that are immediate parents of the given RDD.
+   *
+   * This function will not return more distant ancestors.  For example, if C has a shuffle
+   * dependency on B, which has a shuffle dependency on A:
+   *
+   * A <-- B <-- C
+   *
+   * calling this function with rdd C will only return the B <-- C dependency.
+   *
+   * This function is scheduler-visible for the purpose of unit testing.
+   */
+  private[scheduler] def getShuffleDependencies(
+      rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = {
+    val parents = new HashSet[ShuffleDependency[_, _, _]]
+    val visited = new HashSet[RDD[_]]
+    val waitingForVisit = new Stack[RDD[_]]
     waitingForVisit.push(rdd)
     while (waitingForVisit.nonEmpty) {
-      visit(waitingForVisit.pop())
+      val toVisit = waitingForVisit.pop()
+      if (!visited(toVisit)) {
+        visited += toVisit
+        toVisit.dependencies.foreach {
+          case shuffleDep: ShuffleDependency[_, _, _] =>
+            parents += shuffleDep
+          case dependency =>
+            waitingForVisit.push(dependency.rdd)
+        }
+      }
     }
     parents
   }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 3c30ec8ee8..ab8e95314f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -2023,6 +2023,37 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     assertDataStructuresEmpty()
   }
 
+  /**
+   * Checks the DAGScheduler's internal logic for traversing an RDD DAG by making sure that
+   * getShuffleDependencies correctly returns the direct shuffle dependencies of a particular
+   * RDD. The test creates the following RDD graph (where n denotes a narrow dependency and s
+   * denotes a shuffle dependency):
+   *
+   * A <------------s---------,
+   *                           \
+   * B <--s-- C <--s-- D <--n---`-- E
+   *
+   * Here, the direct shuffle dependency of C is just the shuffle dependency on B. The direct
+   * shuffle dependencies of E are the shuffle dependency on A and the shuffle dependency on C.
+   */
+  test("getShuffleDependencies correctly returns only direct shuffle parents") {
+    val rddA = new MyRDD(sc, 2, Nil)
+    val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(1))
+    val rddB = new MyRDD(sc, 2, Nil)
+    val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(1))
+    val rddC = new MyRDD(sc, 1, List(shuffleDepB))
+    val shuffleDepC = new ShuffleDependency(rddC, new HashPartitioner(1))
+    val rddD = new MyRDD(sc, 1, List(shuffleDepC))
+    val narrowDepD = new OneToOneDependency(rddD)
+    val rddE = new MyRDD(sc, 1, List(shuffleDepA, narrowDepD), tracker = mapOutputTracker)
+
+    assert(scheduler.getShuffleDependencies(rddA) === Set())
+    assert(scheduler.getShuffleDependencies(rddB) === Set())
+    assert(scheduler.getShuffleDependencies(rddC) === Set(shuffleDepB))
+    assert(scheduler.getShuffleDependencies(rddD) === Set(shuffleDepC))
+    assert(scheduler.getShuffleDependencies(rddE) === Set(shuffleDepA, shuffleDepC))
+  }
+
   /**
    * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
    * Note that this checks only the host and not the executor ID.
-- 
GitLab