diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index a2eadbcbd660a81f41ba150e88af05e588a43b52..4e1250a14d7ca4f1552ada2a81051b6cafd630c0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -378,59 +378,63 @@ class DAGScheduler(
    * the provided firstJobId.
    */
   private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
-    val parents = new HashSet[Stage]
-    val visited = new HashSet[RDD[_]]
-    // We are manually maintaining a stack here to prevent StackOverflowError
-    // caused by recursively visiting
-    val waitingForVisit = new Stack[RDD[_]]
-    def visit(r: RDD[_]) {
-      if (!visited(r)) {
-        visited += r
-        // Kind of ugly: need to register RDDs with the cache here since
-        // we can't do it in its constructor because # of partitions is unknown
-        for (dep <- r.dependencies) {
-          dep match {
-            case shufDep: ShuffleDependency[_, _, _] =>
-              parents += getShuffleMapStage(shufDep, firstJobId)
-            case _ =>
-              waitingForVisit.push(dep.rdd)
-          }
-        }
-      }
-    }
-    waitingForVisit.push(rdd)
-    while (waitingForVisit.nonEmpty) {
-      visit(waitingForVisit.pop())
-    }
-    parents.toList
+    getShuffleDependencies(rdd).map { shuffleDep =>
+      getShuffleMapStage(shuffleDep, firstJobId)
+    }.toList
   }
 
   /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */
   private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
-    val parents = new Stack[ShuffleDependency[_, _, _]]
+    val ancestors = new Stack[ShuffleDependency[_, _, _]]
     val visited = new HashSet[RDD[_]]
     // We are manually maintaining a stack here to prevent StackOverflowError
     // caused by recursively visiting
     val waitingForVisit = new Stack[RDD[_]]
-    def visit(r: RDD[_]) {
-      if (!visited(r)) {
-        visited += r
-        for (dep <- r.dependencies) {
-          dep match {
-            case shufDep: ShuffleDependency[_, _, _] =>
-              if (!shuffleToMapStage.contains(shufDep.shuffleId)) {
-                parents.push(shufDep)
-              }
-            case _ =>
-          }
-          waitingForVisit.push(dep.rdd)
+    waitingForVisit.push(rdd)
+    while (waitingForVisit.nonEmpty) {
+      val toVisit = waitingForVisit.pop()
+      if (!visited(toVisit)) {
+        visited += toVisit
+        getShuffleDependencies(toVisit).foreach { shuffleDep =>
+          if (!shuffleToMapStage.contains(shuffleDep.shuffleId)) {
+            ancestors.push(shuffleDep)
+            waitingForVisit.push(shuffleDep.rdd)
+          } // Otherwise, the dependency and its ancestors have already been registered.
         }
       }
     }
+    ancestors
+  }
+
+  /**
+   * Returns shuffle dependencies that are immediate parents of the given RDD.
+   *
+   * This function will not return more distant ancestors. For example, if C has a shuffle
+   * dependency on B which has a shuffle dependency on A:
+   *
+   * A <-- B <-- C
+   *
+   * calling this function with rdd C will only return the B <-- C dependency.
+   *
+   * This function is scheduler-visible for the purpose of unit testing.
+   */
+  private[scheduler] def getShuffleDependencies(
+      rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = {
+    val parents = new HashSet[ShuffleDependency[_, _, _]]
+    val visited = new HashSet[RDD[_]]
+    val waitingForVisit = new Stack[RDD[_]]
     waitingForVisit.push(rdd)
     while (waitingForVisit.nonEmpty) {
-      visit(waitingForVisit.pop())
+      val toVisit = waitingForVisit.pop()
+      if (!visited(toVisit)) {
+        visited += toVisit
+        toVisit.dependencies.foreach {
+          case shuffleDep: ShuffleDependency[_, _, _] =>
+            parents += shuffleDep
+          case dependency =>
+            waitingForVisit.push(dependency.rdd)
+        }
+      }
     }
     parents
   }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 3c30ec8ee8e32ac13fb7ee8e870095c6e97ad8e3..ab8e95314fdf3dd2b0a3b477ca7c9f24798c8b60 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -2023,6 +2023,37 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     assertDataStructuresEmpty()
   }
 
+  /**
+   * Checks the DAGScheduler's internal logic for traversing a RDD DAG by making sure that
+   * getShuffleDependencies correctly returns the direct shuffle dependencies of a particular
+   * RDD. The test creates the following RDD graph (where n denotes a narrow dependency and s
+   * denotes a shuffle dependency):
+   *
+   * A <------------s---------,
+   *                           \
+   * B <--s-- C <--s-- D <--n---`-- E
+   *
+   * Here, the direct shuffle dependency of C is just the shuffle dependency on B. The direct
+   * shuffle dependencies of E are the shuffle dependency on A and the shuffle dependency on C.
+   */
+  test("getShuffleDependencies correctly returns only direct shuffle parents") {
+    val rddA = new MyRDD(sc, 2, Nil)
+    val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(1))
+    val rddB = new MyRDD(sc, 2, Nil)
+    val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(1))
+    val rddC = new MyRDD(sc, 1, List(shuffleDepB))
+    val shuffleDepC = new ShuffleDependency(rddC, new HashPartitioner(1))
+    val rddD = new MyRDD(sc, 1, List(shuffleDepC))
+    val narrowDepD = new OneToOneDependency(rddD)
+    val rddE = new MyRDD(sc, 1, List(shuffleDepA, narrowDepD), tracker = mapOutputTracker)
+
+    assert(scheduler.getShuffleDependencies(rddA) === Set())
+    assert(scheduler.getShuffleDependencies(rddB) === Set())
+    assert(scheduler.getShuffleDependencies(rddC) === Set(shuffleDepB))
+    assert(scheduler.getShuffleDependencies(rddD) === Set(shuffleDepC))
+    assert(scheduler.getShuffleDependencies(rddE) === Set(shuffleDepA, shuffleDepC))
+  }
+
   /**
    * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
    * Note that this checks only the host and not the executor ID.
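For context, a minimal standalone sketch (not part of the patch) of the traversal pattern the new getShuffleDependencies uses: an explicit stack plus a visited set replaces the recursive visit helper, walks narrow edges only, and records the first shuffle boundary on each path, so deep lineages cannot trigger a StackOverflowError. The Node type, ShuffleParentsSketch object, and directShuffleParents helper below are hypothetical stand-ins for RDD, Dependency, and ShuffleDependency, not Spark APIs.

import scala.collection.mutable

object ShuffleParentsSketch {
  // Hypothetical stand-in for an RDD: parent edges are tagged with isShuffle.
  final case class Node(id: Int, parents: Seq[(Node, Boolean)])

  // Same idea as the patch's getShuffleDependencies: iterative traversal with
  // an explicit stack, stopping at the first shuffle edge on each path.
  def directShuffleParents(root: Node): Set[Node] = {
    val found = mutable.HashSet.empty[Node]
    val visited = mutable.HashSet.empty[Node]
    // Explicit stack instead of recursion, mirroring the DAGScheduler code.
    val waitingForVisit = mutable.Stack(root)
    while (waitingForVisit.nonEmpty) {
      val toVisit = waitingForVisit.pop()
      if (!visited(toVisit)) {
        visited += toVisit
        toVisit.parents.foreach {
          case (parent, true)  => found += parent              // shuffle edge: record, do not cross it
          case (parent, false) => waitingForVisit.push(parent) // narrow edge: keep walking
        }
      }
    }
    found.toSet
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the DAG from the new test: A <--s-- E and B <--s-- C <--s-- D <--n-- E.
    val a = Node(0, Nil)
    val b = Node(1, Nil)
    val c = Node(2, Seq((b, true)))
    val d = Node(3, Seq((c, true)))
    val e = Node(4, Seq((a, true), (d, false)))
    assert(directShuffleParents(c) == Set(b))
    assert(directShuffleParents(e) == Set(a, c))
  }
}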