diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
index 6c577c2685597abafab755f41fa91269132ca801..89173540d400923c435d6502381cb0a036c609a4 100644
--- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -4,12 +4,12 @@ import scala.collection.mutable.{Map, HashMap}
 
 import org.scalatest.FunSuite
 import org.scalatest.BeforeAndAfter
-import org.scalatest.concurrent.AsyncAssertions
 import org.scalatest.concurrent.TimeLimitedTests
 import org.scalatest.mock.EasyMockSugar
 import org.scalatest.time.{Span, Seconds}
 
 import org.easymock.EasyMock._
+import org.easymock.Capture
 import org.easymock.EasyMock
 import org.easymock.{IAnswer, IArgumentMatcher}
 
@@ -30,33 +30,55 @@ import spark.TaskEndReason
 
 import spark.{FetchFailed, Success}
 
-class DAGSchedulerSuite extends FunSuite
-    with BeforeAndAfter with EasyMockSugar with TimeLimitedTests
-    with AsyncAssertions with spark.Logging {
+class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests {
 
-  // If we crash the DAGScheduler thread, our test will probably hang.
+  // impose a time limit on this test in case we don't let the job finish.
   override val timeLimit = Span(5, Seconds)
 
   val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite")
   var scheduler: DAGScheduler = null
-  var w: Waiter = null
   val taskScheduler = mock[TaskScheduler]
   val blockManagerMaster = mock[BlockManagerMaster]
   var mapOutputTracker: MapOutputTracker = null
   var schedulerThread: Thread = null
   var schedulerException: Throwable = null
+
+  /** Set of EasyMock argument matchers that match a TaskSet for a given RDD.
+   * We cache these so we do not create duplicate matchers for the same RDD.
+   * This allows us to easily setup a sequence of expectations for task sets for
+   * that RDD.
+   */
   val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher]
+
+  /** Set of cache locations to return from our mock BlockManagerMaster.
+   * Keys are (rdd ID, partition ID). Anything not present will return an empty
+   * list of cache locations silently.
+   */
   val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
 
-  implicit val mocks = MockObjects(taskScheduler, blockManagerMaster)
+  /** JobWaiter for the last JobSubmitted event we pushed. Stored here so that tests (most of
+   * which will only submit one job) do not need to track it explicitly.
+   */
+  var lastJobWaiter: JobWaiter = null
 
-  def makeBlockManagerId(host: String): BlockManagerId =
-    BlockManagerId("exec-" + host, host, 12345)
+  /** Tell EasyMockSugar what mock objects we want to be configured by expecting {...}
+   * and whenExecuting {...} */
+  implicit val mocks = MockObjects(taskScheduler, blockManagerMaster)
 
+  /** Utility function to reset mocks and set expectations on them. EasyMock wants mock objects
+   * to be reset after each time their expectations are set, and we tend to check mock object
+   * calls over a single call to DAGScheduler.
+   *
+   * We also set a default expectation here that blockManagerMaster.getLocations can be called
+   * and will return values from cacheLocations.
+   */
   def resetExpecting(f: => Unit) {
     reset(taskScheduler)
     reset(blockManagerMaster)
-    expecting(f)
+    expecting {
+      expectGetLocations()
+      f
+    }
   }
 
   before {
@@ -70,45 +92,30 @@ class DAGSchedulerSuite extends FunSuite
     whenExecuting {
       scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
     }
-    w = new Waiter
-    schedulerException = null
-    schedulerThread = new Thread("DAGScheduler under test") {
-      override def run() {
-        try {
-          scheduler.run()
-        } catch {
-          case t: Throwable =>
-            logError("Got exception in DAGScheduler: ", t)
-            schedulerException = t
-        } finally {
-          w.dismiss()
-        }
-      }
-    }
-    schedulerThread.start
-    logInfo("finished before")
   }
 
   after {
-    logInfo("started after")
+    assert(scheduler.processEvent(StopDAGScheduler))
     resetExpecting {
       taskScheduler.stop()
     }
     whenExecuting {
-      scheduler.stop
-      schedulerThread.join
-    }
-    w.await()
-    if (schedulerException != null) {
-      throw new Exception("Exception caught from scheduler thread", schedulerException)
+      scheduler.stop()
     }
     System.clearProperty("spark.master.port")
   }
 
-  // Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
-  // This is a pair RDD type so it can always be used in ShuffleDependencies.
+  def makeBlockManagerId(host: String): BlockManagerId =
+    BlockManagerId("exec-" + host, host, 12345)
+
+  /** Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
+   * This is a pair RDD type so it can always be used in ShuffleDependencies. */
   type MyRDD = RDD[(Int, Int)]
 
+  /** Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+   * preferredLocations (if any) that are passed to them. They are deliberately not executable
+   * so we can test that DAGScheduler does not try to execute RDDs locally.
+   */
   def makeRdd(
         numSplits: Int,
         dependencies: List[Dependency[_]],
@@ -130,6 +137,9 @@ class DAGSchedulerSuite extends FunSuite
     }
   }
 
+  /** EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task
+   * is from a particular RDD.
+   */
   def taskSetForRdd(rdd: MyRDD): TaskSet = {
     val matcher = taskSetMatchers.getOrElseUpdate(rdd,
       new IArgumentMatcher {
@@ -149,6 +159,9 @@ class DAGSchedulerSuite extends FunSuite
     return null
   }
 
+  /** Set up an EasyMock expectation to respond to blockManagerMaster.getLocations() calls with
+   * values from cacheLocations.
+   */
   def expectGetLocations(): Unit = {
     EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])).
         andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] {
@@ -171,51 +184,106 @@ class DAGSchedulerSuite extends FunSuite
     }).anyTimes()
   }
 
-  def expectStageAnd(rdd: MyRDD, results: Seq[(TaskEndReason, Any)],
-      preferredLocations: Option[Seq[Seq[String]]] = None)(afterSubmit: TaskSet => Unit) {
-    // TODO: Remember which submission
-    EasyMock.expect(taskScheduler.submitTasks(taskSetForRdd(rdd))).andAnswer(new IAnswer[Unit] {
-      override def answer(): Unit = {
-        val taskSet = getCurrentArguments()(0).asInstanceOf[TaskSet]
-        for (task <- taskSet.tasks) {
-          task.generation = mapOutputTracker.getGeneration
-        }
-        afterSubmit(taskSet)
-        preferredLocations match {
-          case None =>
-            for (taskLocs <- taskSet.tasks.map(_.preferredLocations)) {
-              w { assert(taskLocs.size === 0) }
-            }
-          case Some(locations) =>
-            w { assert(locations.size === taskSet.tasks.size) }
-            for ((expectLocs, taskLocs) <-
-                    taskSet.tasks.map(_.preferredLocations).zip(locations)) {
-              w { assert(expectLocs === taskLocs) }
-            }
-        }
-        w { assert(taskSet.tasks.size >= results.size)}
-        for ((result, i) <- results.zipWithIndex) {
-          if (i < taskSet.tasks.size) {
-            scheduler.taskEnded(taskSet.tasks(i), result._1, result._2, Map[Long, Any]())
-          }
-        }
+  /** Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
+   * the scheduler not to exit.
+   *
+   * After processing the event, submit waiting stages as is done on most iterations of the
+   * DAGScheduler event loop.
+   */
+  def runEvent(event: DAGSchedulerEvent) {
+    assert(!scheduler.processEvent(event))  // presumably true means "stop"; see after { ... }
+    scheduler.submitWaitingStages()
+  }
+
+  /** Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be
+   * called from a resetExpecting { ... } block.
+   *
+   * Returns an EasyMock Capture that will contain the task set after the stage is submitted.
+   * Most tests should use interceptStage() instead of calling this directly.
+   */
+  def expectStage(rdd: MyRDD): Capture[TaskSet] = {
+    val taskSetCapture = new Capture[TaskSet]
+    taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd)))
+    return taskSetCapture
+  }
+
+  /** Expect the supplied code snippet to submit a stage for the specified RDD.
+   * Return the resulting TaskSet. Before returning, marks all the tasks as belonging to the
+   * current MapOutputTracker generation.
+   */
+  def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = {
+    var capture: Capture[TaskSet] = null  // assigned inside the resetExpecting block below
+    resetExpecting {
+      capture = expectStage(rdd)
+    }
+    whenExecuting {
+      f
+    }
+    val taskSet = capture.getValue
+    for (task <- taskSet.tasks) {
+      task.generation = mapOutputTracker.getGeneration
+    }
+    return taskSet
+  }
+
+  /** Send the given CompletionEvent messages for the tasks in the TaskSet. */
+  def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+    assert(taskSet.tasks.size >= results.size)
+    for ((result, i) <- results.zipWithIndex) {
+      if (i < taskSet.tasks.size) {
+        runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
       }
-    })
+    }
   }
 
-  def expectStage(rdd: MyRDD, results: Seq[(TaskEndReason, Any)],
-                  preferredLocations: Option[Seq[Seq[String]]] = None) {
-    expectStageAnd(rdd, results, preferredLocations) { _ => }
+  /** Assert that the supplied TaskSet has exactly the given preferredLocations. */
+  def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
+    assert(locations.size === taskSet.tasks.size)
+    for ((taskLocs, expectLocs) <-
+            taskSet.tasks.map(_.preferredLocations).zip(locations)) {
+      assert(taskLocs === expectLocs)  // left: the task's actual locations; right: expected
+    }
   }
 
-  def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): Array[Int] = {
-    return scheduler.runJob[(Int, Int), Int](
+  /** When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+   * below, we do not expect this function to ever be executed; instead, we will return results
+   * directly through CompletionEvents. If run, it yields the key of the first pair in `it`.
+   */
+  def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int =
+    it.next._1
+
+
+  /** Start a job to compute the given RDD. Returns the JobWaiter that will
+   * collect the result of the job via callbacks from DAGScheduler. */
+  def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): JobWaiter = {
+    val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int](
         rdd,
-        (context: TaskContext, it: Iterator[(Int, Int)]) => it.next._1.asInstanceOf[Int],
+        jobComputeFunc,
         (0 to (rdd.splits.size - 1)),
         "test-site",
         allowLocal
     )
+    lastJobWaiter = waiter
+    runEvent(toSubmit)
+    return waiter
+  }
+
+  /** Assert that a job we started has failed (defaults to the last job submitted). */
+  def expectJobException(waiter: JobWaiter = lastJobWaiter) {
+    waiter.getResult match {
+      case JobSucceeded(_) => fail("Expected the job to fail, but it succeeded")
+      case JobFailed(_) => ()  // a failure is exactly what the caller asked for
+    }
+  }
+
+  /** Assert that a job we started has succeeded and has the given result. */
+  def expectJobResult(expected: Array[Int], waiter: JobWaiter = lastJobWaiter) {
+    waiter.getResult match {
+      case JobSucceeded(answer) =>
+        assert(expected === answer.asInstanceOf[Seq[Int]].toArray)
+      case JobFailed(cause) =>
+        fail("Expected the job to succeed, but it failed: " + cause)
+    }
   }
 
   def makeMapStatus(host: String, reduces: Int): MapStatus =
@@ -223,24 +291,14 @@ class DAGSchedulerSuite extends FunSuite
 
   test("zero split job") {
     val rdd = makeRdd(0, Nil)
-    resetExpecting {
-      expectGetLocations()
-      // deliberately expect no stages to be submitted
-    }
-    whenExecuting {
-      assert(submitRdd(rdd) === Array[Int]())
-    }
+    assert(scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false) === Array[Int]())
   }
 
   test("run trivial job") {
     val rdd = makeRdd(1, Nil)
-    resetExpecting {
-      expectGetLocations()
-      expectStage(rdd, List( (Success, 42) ))
-    }
-    whenExecuting {
-      assert(submitRdd(rdd) === Array(42))
-    }
+    val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+    respondToTaskSet(taskSet, List( (Success, 42) ))
+    expectJobResult(Array(42))
   }
 
   test("local job") {
@@ -251,51 +309,34 @@ class DAGSchedulerSuite extends FunSuite
       override def getPreferredLocations(split: Split) = Nil
       override def toString = "DAGSchedulerSuite Local RDD"
     }
-    resetExpecting {
-      expectGetLocations()
-      // deliberately expect no stages to be submitted
-    }
-    whenExecuting {
-      assert(submitRdd(rdd, true) === Array(42))
-    }
+    submitRdd(rdd, true)
+    expectJobResult(Array(42))
   }
 
   test("run trivial job w/ dependency") {
     val baseRdd = makeRdd(1, Nil)
     val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
-    resetExpecting {
-      expectGetLocations()
-      expectStage(finalRdd, List( (Success, 42) ))
-    }
-    whenExecuting {
-      assert(submitRdd(finalRdd) === Array(42))
-    }
+    val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+    respondToTaskSet(taskSet, List( (Success, 42) ))
+    expectJobResult(Array(42))
   }
 
-  test("location preferences w/ dependency") {
+  test("cache location preferences w/ dependency") {
     val baseRdd = makeRdd(1, Nil)
     val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
-    resetExpecting {
-      expectGetLocations()
-      cacheLocations(baseRdd.id -> 0) =
-        Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
-      expectStage(finalRdd, List( (Success, 42) ),
-                  Some(List(Seq("hostA", "hostB"))))
-    }
-    whenExecuting {
-      assert(submitRdd(finalRdd) === Array(42))
-    }
+    cacheLocations(baseRdd.id -> 0) =
+      Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+    val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+    expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB")))
+    respondToTaskSet(taskSet, List( (Success, 42) ))
+    expectJobResult(Array(42))
   }
 
   test("trivial job failure") {
     val rdd = makeRdd(1, Nil)
-    resetExpecting {
-      expectGetLocations()
-      expectStageAnd(rdd, List()) { taskSet => scheduler.taskSetFailed(taskSet, "test failure") }
-    }
-    whenExecuting(taskScheduler, blockManagerMaster) {
-      intercept[SparkException] { submitRdd(rdd) }
-    }
+    val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+    runEvent(TaskSetFailed(taskSet, "test failure"))
+    expectJobException()
   }
 
   test("run trivial shuffle") {
@@ -304,20 +345,17 @@ class DAGSchedulerSuite extends FunSuite
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(1, List(shuffleDep))
 
-    resetExpecting {
-      expectGetLocations()
-      expectStage(shuffleMapRdd, List(
+    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+    val secondStage = interceptStage(reduceRdd) {
+      respondToTaskSet(firstStage, List(
         (Success, makeMapStatus("hostA", 1)),
         (Success, makeMapStatus("hostB", 1))
       ))
-      expectStageAnd(reduceRdd, List( (Success, 42) )) { _ =>
-        w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
-                   Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) }
-      }
-    }
-    whenExecuting {
-      assert(submitRdd(reduceRdd) === Array(42))
     }
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+           Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+    respondToTaskSet(secondStage, List( (Success, 42) ))
+    expectJobResult(Array(42))
   }
 
   test("run trivial shuffle with fetch failure") {
@@ -326,28 +364,32 @@ class DAGSchedulerSuite extends FunSuite
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(2, List(shuffleDep))
 
-    resetExpecting {
-      expectGetLocations()
-      expectStage(shuffleMapRdd, List(
+    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+    val secondStage = interceptStage(reduceRdd) {
+      respondToTaskSet(firstStage, List(
         (Success, makeMapStatus("hostA", 1)),
         (Success, makeMapStatus("hostB", 1))
       ))
+    }
+    resetExpecting {
       blockManagerMaster.removeExecutor("exec-hostA")
-      expectStage(reduceRdd, List(
+    }
+    whenExecuting {
+      respondToTaskSet(secondStage, List(
         (Success, 42),
         (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)
       ))
-      // partial recompute
-      expectStage(shuffleMapRdd, List( (Success, makeMapStatus("hostA", 1)) ))
-      expectStageAnd(reduceRdd, List( (Success, 43) )) { _ =>
-        w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
-                   Array(makeBlockManagerId("hostA"),
-                         makeBlockManagerId("hostB"))) }
-      }
     }
-    whenExecuting {
-      assert(submitRdd(reduceRdd) === Array(42, 43))
+    val thirdStage = interceptStage(shuffleMapRdd) {
+      scheduler.resubmitFailedStages()
+    }
+    val fourthStage = interceptStage(reduceRdd) {
+      respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) ))
     }
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+                   Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+    respondToTaskSet(fourthStage, List( (Success, 43) ))
+    expectJobResult(Array(42, 43))
   }
 
   test("ignore late map task completions") {
@@ -356,63 +398,64 @@ class DAGSchedulerSuite extends FunSuite
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(2, List(shuffleDep))
 
+    val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+    val oldGeneration = mapOutputTracker.getGeneration
     resetExpecting {
-      expectGetLocations()
-      expectStageAnd(shuffleMapRdd, List(
-        (Success, makeMapStatus("hostA", 1))
-      )) { taskSet =>
-        val newGeneration = mapOutputTracker.getGeneration + 1
-        scheduler.executorLost("exec-hostA")
-        val noAccum = Map[Long, Any]()
-        // We rely on the event queue being ordered and increasing the generation number by 1
-        // should be ignored for being too old
-        scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum)
-        // should work because it's a non-failed host
-        scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum)
-        // should be ignored for being too old
-        scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum)
-        // should be ignored (not end the stage) because it's too old
-        scheduler.taskEnded(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum)
-        taskSet.tasks(1).generation = newGeneration
-        scheduler.taskEnded(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum)
-      }
       blockManagerMaster.removeExecutor("exec-hostA")
-      expectStageAnd(reduceRdd, List(
-        (Success, 42), (Success, 43)
-      )) { _ =>
-        w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
-                   Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) }
-      }
     }
     whenExecuting {
-      assert(submitRdd(reduceRdd) === Array(42, 43))
-    }
+      runEvent(ExecutorLost("exec-hostA"))
+    }
+    val newGeneration = mapOutputTracker.getGeneration
+    assert(newGeneration > oldGeneration)
+    val noAccum = Map[Long, Any]()
+    // We rely on the event queue being ordered and increasing the generation number by 1
+    // should be ignored for being too old
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+    // should work because it's a non-failed host
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum))
+    // should be ignored for being too old
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+    taskSet.tasks(1).generation = newGeneration
+    val secondStage = interceptStage(reduceRdd) {
+      runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
+    }
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+           Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+    respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) ))
+    expectJobResult(Array(42, 43))
   }
 
-  test("run trivial shuffle with out-of-band failure") {
+  test("run trivial shuffle with out-of-band failure and retry") {
     val shuffleMapRdd = makeRdd(2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(1, List(shuffleDep))
+
+    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
     resetExpecting {
-      expectGetLocations()
       blockManagerMaster.removeExecutor("exec-hostA")
-      expectStageAnd(shuffleMapRdd, List(
+    }
+    whenExecuting {
+      runEvent(ExecutorLost("exec-hostA"))
+    }
+    // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
+    // rather than marking it as failed and waiting.
+    val secondStage = interceptStage(shuffleMapRdd) {
+      respondToTaskSet(firstStage, List(
         (Success, makeMapStatus("hostA", 1)),
         (Success, makeMapStatus("hostB", 1))
-      )) { _ => scheduler.executorLost("exec-hostA") }
-      expectStage(shuffleMapRdd, List(
-        (Success, makeMapStatus("hostC", 1))
       ))
-      expectStageAnd(reduceRdd, List( (Success, 42) )) { _ =>
-        w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
-                   Array(makeBlockManagerId("hostC"),
-                         makeBlockManagerId("hostB"))) }
-      }
     }
-    whenExecuting {
-      assert(submitRdd(reduceRdd) === Array(42))
+    val thirdStage = interceptStage(reduceRdd) {
+      respondToTaskSet(secondStage, List(
+        (Success, makeMapStatus("hostC", 1))
+      ))
     }
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+           Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+    respondToTaskSet(thirdStage, List( (Success, 42) ))
+    expectJobResult(Array(42))
   }
 
   test("recursive shuffle failures") {
@@ -422,34 +465,42 @@ class DAGSchedulerSuite extends FunSuite
     val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
     val finalRdd = makeRdd(1, List(shuffleDepTwo))
 
-    resetExpecting {
-      expectGetLocations()
-      expectStage(shuffleOneRdd, List(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))
+    val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+    val secondStage = interceptStage(shuffleTwoRdd) {
+      respondToTaskSet(firstStage, List(
+        (Success, makeMapStatus("hostA", 2)),
+        (Success, makeMapStatus("hostB", 2))
       ))
-      expectStage(shuffleTwoRdd, List(
+    }
+    val thirdStage = interceptStage(finalRdd) {
+      respondToTaskSet(secondStage, List(
         (Success, makeMapStatus("hostA", 1)),
         (Success, makeMapStatus("hostC", 1))
       ))
+    }
+    resetExpecting {
       blockManagerMaster.removeExecutor("exec-hostA")
-      expectStage(finalRdd, List(
+    }
+    whenExecuting {
+      respondToTaskSet(thirdStage, List(
         (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
       ))
-      // triggers a partial recompute of the first stage, then the second
-      expectStage(shuffleOneRdd, List(
-        (Success, makeMapStatus("hostA", 1))
+    }
+    val recomputeOne = interceptStage(shuffleOneRdd) {
+      scheduler.resubmitFailedStages
+    }
+    val recomputeTwo = interceptStage(shuffleTwoRdd) {
+      respondToTaskSet(recomputeOne, List(
+        (Success, makeMapStatus("hostA", 2))
       ))
-      expectStage(shuffleTwoRdd, List(
+    }
+    val finalStage = interceptStage(finalRdd) {
+      respondToTaskSet(recomputeTwo, List(
         (Success, makeMapStatus("hostA", 1))
       ))
-      expectStage(finalRdd, List(
-        (Success, 42)
-      ))
-    }
-    whenExecuting {
-      assert(submitRdd(finalRdd) === Array(42))
     }
+    respondToTaskSet(finalStage, List( (Success, 42) ))
+    expectJobResult(Array(42))
   }
 
   test("cached post-shuffle") {
@@ -459,35 +510,41 @@ class DAGSchedulerSuite extends FunSuite
     val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
     val finalRdd = makeRdd(1, List(shuffleDepTwo))
 
-    resetExpecting {
-      expectGetLocations()
-      expectStage(shuffleOneRdd, List(
+    val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+    cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+    cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+    val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+      respondToTaskSet(firstShuffleStage, List(
+        (Success, makeMapStatus("hostA", 2)),
+        (Success, makeMapStatus("hostB", 2))
+      ))
+    }
+    val reduceStage = interceptStage(finalRdd) {
+      respondToTaskSet(secondShuffleStage, List(
         (Success, makeMapStatus("hostA", 1)),
         (Success, makeMapStatus("hostB", 1))
       ))
-      expectStageAnd(shuffleTwoRdd, List(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostC", 1))
-      )){ _ =>
-        cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
-        cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
-      }
+    }
+    resetExpecting {
       blockManagerMaster.removeExecutor("exec-hostA")
-      expectStage(finalRdd, List(
+    }
+    whenExecuting {
+      respondToTaskSet(reduceStage, List(
         (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
       ))
-      // since we have a cached copy of the missing split of shuffleTwoRdd, we shouldn't
-      // immediately try to rerun shuffleOneRdd:
-      expectStage(shuffleTwoRdd, List(
+    }
+    // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
+    val recomputeTwo = interceptStage(shuffleTwoRdd) {
+      scheduler.resubmitFailedStages()
+    }
+    expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD")))
+    val finalRetry = interceptStage(finalRdd) {
+      respondToTaskSet(recomputeTwo, List(
         (Success, makeMapStatus("hostD", 1))
-      ), Some(Seq(List("hostD"))))
-      expectStage(finalRdd, List(
-        (Success, 42)
       ))
     }
-    whenExecuting {
-      assert(submitRdd(finalRdd) === Array(42))
-    }
+    respondToTaskSet(finalRetry, List( (Success, 42) ))
+    expectJobResult(Array(42))
   }
 
   test("cached post-shuffle but fails") {
@@ -497,45 +554,58 @@ class DAGSchedulerSuite extends FunSuite
     val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
     val finalRdd = makeRdd(1, List(shuffleDepTwo))
 
-    resetExpecting {
-      expectGetLocations()
-      expectStage(shuffleOneRdd, List(
+    val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+    cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+    cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+    val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+      respondToTaskSet(firstShuffleStage, List(
+        (Success, makeMapStatus("hostA", 2)),
+        (Success, makeMapStatus("hostB", 2))
+      ))
+    }
+    val reduceStage = interceptStage(finalRdd) {
+      respondToTaskSet(secondShuffleStage, List(
         (Success, makeMapStatus("hostA", 1)),
         (Success, makeMapStatus("hostB", 1))
       ))
-      expectStageAnd(shuffleTwoRdd, List(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostC", 1))
-      )){ _ =>
-        cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
-        cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
-      }
+    }
+    resetExpecting {
       blockManagerMaster.removeExecutor("exec-hostA")
-      expectStage(finalRdd, List(
+    }
+    whenExecuting {
+      respondToTaskSet(reduceStage, List(
         (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
       ))
-      // since we have a cached copy of the missing split of shuffleTwoRdd, we shouldn't
-      // immediately try to rerun shuffleOneRdd:
-      expectStageAnd(shuffleTwoRdd, List(
-        (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null)
-      ), Some(Seq(List("hostD")))) { _ =>
-        w {
-          intercept[FetchFailedException]{
-            mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0)
-          }
-        }
-        cacheLocations.remove(shuffleTwoRdd.id -> 0)
-      }
-      // after that fetch failure, we should refetch the cache locations and try to recompute
-      // the whole chain. Note that we will ignore that a fetch failure previously occured on
-      // this host.
-      expectStage(shuffleOneRdd, List( (Success, makeMapStatus("hostA", 1)) ))
-      expectStage(shuffleTwoRdd, List( (Success, makeMapStatus("hostA", 1)) ))
-      expectStage(finalRdd, List( (Success, 42) ))
     }
-    whenExecuting {
-      assert(submitRdd(finalRdd) === Array(42))
+    val recomputeTwoCached = interceptStage(shuffleTwoRdd) {
+      scheduler.resubmitFailedStages()
+    }
+    expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD")))
+    intercept[FetchFailedException]{
+      mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0)
+    }
+
+    // Simulate the shuffle input data failing to be cached.
+    cacheLocations.remove(shuffleTwoRdd.id -> 0)
+    respondToTaskSet(recomputeTwoCached, List(
+      (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null)
+    ))
+
+    // After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit
+    // everything.
+    val recomputeOne = interceptStage(shuffleOneRdd) {
+      scheduler.resubmitFailedStages()
     }
+    // We use hostA here to make sure DAGScheduler doesn't think it's still dead.
+    val recomputeTwoUncached = interceptStage(shuffleTwoRdd) {
+      respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) ))
+    }
+    expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]()))
+    val finalRetry = interceptStage(finalRdd) {
+      respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) ))
+
+    }
+    respondToTaskSet(finalRetry, List( (Success, 42) ))
+    expectJobResult(Array(42))
   }
 }
-