diff --git a/core/pom.xml b/core/pom.xml
index 873e8a1d0fe33892296d2fe8fa230eb3c2c8c83a..66c62151feb757d5da56edaf9c53eb8028a4675e 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -98,6 +98,11 @@
       <artifactId>scalacheck_${scala.version}</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.easymock</groupId>
+      <artifactId>easymock</artifactId>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>com.novocode</groupId>
       <artifactId>junit-interface</artifactId>
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index ddbf8f95d9390bc37654328ccf07950eff6abb3f..2ed458c6fe3e288deb847f9281942cc0b531c9c2 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -187,6 +187,7 @@ class SparkContext(
   taskScheduler.start()
 
   private var dagScheduler = new DAGScheduler(taskScheduler)
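+  // Start the event loop thread explicitly; it is no longer started in the DAGScheduler
+  // constructor, so tests can construct a scheduler without spawning the thread.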
+  dagScheduler.start()
 
   /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
   val hadoopConfiguration = {
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 908a22b2dfa0e39cc1aedc622142601146a93481..8cfc08e5acac3501c677d16cf55b0a734da5676e 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -23,7 +23,16 @@ import util.{MetadataCleaner, TimeStampedHashMap}
  * and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
  */
 private[spark]
-class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
+class DAGScheduler(
+    taskSched: TaskScheduler,
+    mapOutputTracker: MapOutputTracker,
+    blockManagerMaster: BlockManagerMaster,
+    env: SparkEnv)
+  extends TaskSchedulerListener with Logging {
+
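+  // Auxiliary constructor used by production code: it pulls the remaining collaborators from
+  // the current SparkEnv, while the primary constructor above lets tests inject mocks for them.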
+  def this(taskSched: TaskScheduler) {
+    this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+  }
   taskSched.setListener(this)
 
   // Called by TaskScheduler to report task completions or failures.
@@ -66,10 +75,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
 
   var cacheLocs = new HashMap[Int, Array[List[String]]]
 
-  val env = SparkEnv.get
-  val mapOutputTracker = env.mapOutputTracker
-  val blockManagerMaster = env.blockManager.master
-
   // For tracking failed nodes, we use the MapOutputTracker's generation number, which is
   // sent with every task. When we detect a node failing, we note the current generation number
   // and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask
@@ -90,12 +95,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
   val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
 
   // Start a thread to run the DAGScheduler event loop
-  new Thread("DAGScheduler") {
-    setDaemon(true)
-    override def run() {
-      DAGScheduler.this.run()
-    }
-  }.start()
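+  // Thread creation now lives in start() (called from SparkContext) so that tests can drive
+  // processEvent() directly without running the event loop thread.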
+  def start() {
+    new Thread("DAGScheduler") {
+      setDaemon(true)
+      override def run() {
+        DAGScheduler.this.run()
+      }
+    }.start()
+  }
 
   private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
     if (!cacheLocs.contains(rdd.id)) {
@@ -198,6 +205,28 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     missing.toList
   }
 
+  /** Returns (and does not submit) a JobSubmitted event suitable to run a given job, and
+   * a JobWaiter whose awaitResult() method will return the result of the job when it is complete.
+   *
+   * The job is assumed to have at least one partition; zero partition jobs should be handled
+   * without a JobSubmitted event.
+   */
+  private[scheduler] def prepareJob[T, U: ClassManifest](
+      finalRdd: RDD[T],
+      func: (TaskContext, Iterator[T]) => U,
+      partitions: Seq[Int],
+      callSite: String,
+      allowLocal: Boolean,
+      resultHandler: (Int, U) => Unit)
+    : (JobSubmitted, JobWaiter[U]) =
+  {
+    assert(partitions.size > 0)
+    val waiter = new JobWaiter(partitions.size, resultHandler)
+    val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+    val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)
+    return (toSubmit, waiter)
+  }
+
   def runJob[T, U: ClassManifest](
       finalRdd: RDD[T],
       func: (TaskContext, Iterator[T]) => U,
@@ -209,9 +238,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     if (partitions.size == 0) {
       return
     }
-    val waiter = new JobWaiter(partitions.size, resultHandler)
-    val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
-    eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter))
+    val (toSubmit, waiter) = prepareJob(
+        finalRdd, func, partitions, callSite, allowLocal, resultHandler)
+    eventQueue.put(toSubmit)
     waiter.awaitResult() match {
       case JobSucceeded => {}
       case JobFailed(exception: Exception) =>
@@ -235,6 +264,81 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     return listener.awaitResult()    // Will throw an exception if the job fails
   }
 
+  /** Process one event retrieved from the event queue.
+   * Returns true if we should stop the event loop.
+   */
+  private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
+    event match {
+      case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
+        val runId = nextRunId.getAndIncrement()
+        val finalStage = newStage(finalRDD, None, runId)
+        val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
+        clearCacheLocs()
+        logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
+                " output partitions (allowLocal=" + allowLocal + ")")
+        logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
+        logInfo("Parents of final stage: " + finalStage.parents)
+        logInfo("Missing parents: " + getMissingParentStages(finalStage))
+        if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
+          // Compute very short actions like first() or take() with no parent stages locally.
+          runLocally(job)
+        } else {
+          activeJobs += job
+          resultStageToJob(finalStage) = job
+          submitStage(finalStage)
+        }
+
+      case ExecutorLost(execId) =>
+        handleExecutorLost(execId)
+
+      case completion: CompletionEvent =>
+        handleTaskCompletion(completion)
+
+      case TaskSetFailed(taskSet, reason) =>
+        abortStage(idToStage(taskSet.stageId), reason)
+
+      case StopDAGScheduler =>
+        // Cancel any active jobs
+        for (job <- activeJobs) {
+          val error = new SparkException("Job cancelled because SparkContext was shut down")
+          job.listener.jobFailed(error)
+        }
+        return true
+    }
+    return false
+  }
+
+  /** Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
+   * the last fetch failure.
+   */
+  private[scheduler] def resubmitFailedStages() {
+    logInfo("Resubmitting failed stages")
+    clearCacheLocs()
+    val failed2 = failed.toArray
+    failed.clear()
+    for (stage <- failed2.sortBy(_.priority)) {
+      submitStage(stage)
+    }
+  }
+
+  /** Check for waiting or failed stages which are now eligible for resubmission.
+   * Ordinarily run on every iteration of the event loop.
+   */
+  private[scheduler] def submitWaitingStages() {
+    // TODO: We might want to run this less often, when we are sure that something has become
+    // runnable that wasn't before.
+    logTrace("Checking for newly runnable parent stages")
+    logTrace("running: " + running)
+    logTrace("waiting: " + waiting)
+    logTrace("failed: " + failed)
+    val waiting2 = waiting.toArray
+    waiting.clear()
+    for (stage <- waiting2.sortBy(_.priority)) {
+      submitStage(stage)
+    }
+  }
+
   /**
    * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
    * events and responds by launching tasks. This runs in a dedicated thread and receives events
@@ -245,77 +349,26 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
 
     while (true) {
       val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
-      val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
       if (event != null) {
         logDebug("Got event of type " + event.getClass.getName)
       }
 
-      event match {
-        case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
-          val runId = nextRunId.getAndIncrement()
-          val finalStage = newStage(finalRDD, None, runId)
-          val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
-          clearCacheLocs()
-          logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
-                  " output partitions")
-          logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
-          logInfo("Parents of final stage: " + finalStage.parents)
-          logInfo("Missing parents: " + getMissingParentStages(finalStage))
-          if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
-            // Compute very short actions like first() or take() with no parent stages locally.
-            runLocally(job)
-          } else {
-            activeJobs += job
-            resultStageToJob(finalStage) = job
-            submitStage(finalStage)
-          }
-
-        case ExecutorLost(execId) =>
-          handleExecutorLost(execId)
-
-        case completion: CompletionEvent =>
-          handleTaskCompletion(completion)
-
-        case TaskSetFailed(taskSet, reason) =>
-          abortStage(idToStage(taskSet.stageId), reason)
-
-        case StopDAGScheduler =>
-          // Cancel any active jobs
-          for (job <- activeJobs) {
-            val error = new SparkException("Job cancelled because SparkContext was shut down")
-            job.listener.jobFailed(error)
-          }
+      if (event != null) {
+        if (processEvent(event)) {
           return
-
-        case null =>
-          // queue.poll() timed out, ignore it
+        }
       }
 
+      val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
       // Periodically resubmit failed stages if some map output fetches have failed and we have
       // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
       // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
       // the same time, so we want to make sure we've identified all the reduce tasks that depend
       // on the failed node.
       if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
-        logInfo("Resubmitting failed stages")
-        clearCacheLocs()
-        val failed2 = failed.toArray
-        failed.clear()
-        for (stage <- failed2.sortBy(_.priority)) {
-          submitStage(stage)
-        }
+        resubmitFailedStages()
       } else {
-        // TODO: We might want to run this less often, when we are sure that something has become
-        // runnable that wasn't before.
-        logTrace("Checking for newly runnable parent stages")
-        logTrace("running: " + running)
-        logTrace("waiting: " + waiting)
-        logTrace("failed: " + failed)
-        val waiting2 = waiting.toArray
-        waiting.clear()
-        for (stage <- waiting2.sortBy(_.priority)) {
-          submitStage(stage)
-        }
+        submitWaitingStages()
       }
     }
   }
@@ -547,7 +600,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) {
       failedGeneration(execId) = currentGeneration
       logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration))
-      env.blockManager.master.removeExecutor(execId)
+      blockManagerMaster.removeExecutor(execId)
       // TODO: This will be really slow if we keep accumulating shuffle map stages
       for ((shuffleId, stage) <- shuffleToMapStage) {
         stage.removeOutputsOnExecutor(execId)
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..83663ac702a5be3d7ea9c89c6ca9b6054adbfee7
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -0,0 +1,663 @@
+package spark.scheduler
+
+import scala.collection.mutable.{Map, HashMap}
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.TimeLimitedTests
+import org.scalatest.mock.EasyMockSugar
+import org.scalatest.time.{Span, Seconds}
+
+import org.easymock.EasyMock._
+import org.easymock.Capture
+import org.easymock.EasyMock
+import org.easymock.{IAnswer, IArgumentMatcher}
+
+import akka.actor.ActorSystem
+
+import spark.storage.BlockManager
+import spark.storage.BlockManagerId
+import spark.storage.BlockManagerMaster
+import spark.{Dependency, ShuffleDependency, OneToOneDependency}
+import spark.FetchFailedException
+import spark.MapOutputTracker
+import spark.RDD
+import spark.SparkContext
+import spark.SparkException
+import spark.Split
+import spark.TaskContext
+import spark.TaskEndReason
+
+import spark.{FetchFailed, Success}
+
+/**
+ * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
+ * rather than spawning an event loop thread as happens in the real code. They use EasyMock
+ * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
+ * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead
+ * host notifications are sent). In addition, tests may check for side effects on a non-mocked
+ * MapOutputTracker instance.
+ *
+ * Tests primarily consist of running DAGScheduler#processEvent and
+ * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
+ * and capturing the resulting TaskSets from the mock TaskScheduler.
+ */
+class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests {
+
+  // Impose a time limit on these tests in case we don't let a job finish, in which case
+  // JobWaiter#awaitResult will hang.
+  override val timeLimit = Span(5, Seconds)
+
+  val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite")
+  var scheduler: DAGScheduler = null
+  val taskScheduler = mock[TaskScheduler]
+  val blockManagerMaster = mock[BlockManagerMaster]
+  var mapOutputTracker: MapOutputTracker = null
+  var schedulerThread: Thread = null
+  var schedulerException: Throwable = null
+
+  /**
+   * Cache of EasyMock argument matchers that match a TaskSet for a given RDD.
+   * We reuse these so we do not create duplicate matchers for the same RDD,
+   * which lets us easily set up a sequence of expectations for task sets for
+   * that RDD.
+   */
+  val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher]
+
+  /**
+   * Map of cache locations to return from our mock BlockManagerMaster.
+   * Keys are (rdd ID, partition ID). Any key not present silently returns an
+   * empty list of cache locations.
+   */
+  val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+
+  /**
+   * JobWaiter for the last JobSubmitted event we pushed, so that tests (most of which
+   * submit only one job) do not need to track it explicitly.
+   */
+  var lastJobWaiter: JobWaiter[Int] = null
+
+  /**
+   * Array into which we are accumulating the results from the last job asynchronously.
+   */
+  var lastJobResult: Array[Int] = null
+
+  /**
+   * Tell EasyMockSugar which mock objects we want configured by expecting {...}
+   * and whenExecuting {...} blocks.
+   */
+  implicit val mocks = MockObjects(taskScheduler, blockManagerMaster)
+
+  /**
+   * Utility function to reset mocks and set expectations on them. EasyMock wants mock objects
+   * to be reset after each time their expectations are set, and we tend to check mock object
+   * calls over a single call to DAGScheduler.
+   *
+   * We also set a default expectation here that blockManagerMaster.getLocations can be called
+   * and will return values from cacheLocations.
+   */
+  def resetExpecting(f: => Unit) {
+    reset(taskScheduler)
+    reset(blockManagerMaster)
+    expecting {
+      expectGetLocations()
+      f
+    }
+  }
+
+  before {
+    taskSetMatchers.clear()
+    cacheLocations.clear()
+    val actorSystem = ActorSystem("test")
+    mapOutputTracker = new MapOutputTracker(actorSystem, true)
+    resetExpecting {
+      taskScheduler.setListener(anyObject())
+    }
+    whenExecuting {
+      scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
+    }
+  }
+
+  after {
+    assert(scheduler.processEvent(StopDAGScheduler))
+    resetExpecting {
+      taskScheduler.stop()
+    }
+    whenExecuting {
+      scheduler.stop()
+    }
+    sc.stop()
+    System.clearProperty("spark.master.port")
+  }
+
+  def makeBlockManagerId(host: String): BlockManagerId =
+    BlockManagerId("exec-" + host, host, 12345)
+
+  /**
+   * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
+   * This is a pair RDD type so it can always be used in ShuffleDependencies.
+   */
+  type MyRDD = RDD[(Int, Int)]
+
+  /**
+   * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+   * preferredLocations (if any) that are passed to them. They are deliberately not executable
+   * so we can test that DAGScheduler does not try to execute RDDs locally.
+   */
+  def makeRdd(
+        numSplits: Int,
+        dependencies: List[Dependency[_]],
+        locations: Seq[Seq[String]] = Nil
+      ): MyRDD = {
+    val maxSplit = numSplits - 1
+    return new MyRDD(sc, dependencies) {
+      override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
+        throw new RuntimeException("should not be reached")
+      override def getSplits() = (0 to maxSplit).map(i => new Split {
+        override def index = i
+      }).toArray
+      override def getPreferredLocations(split: Split): Seq[String] =
+        if (locations.isDefinedAt(split.index))
+          locations(split.index)
+        else
+          Nil
+      override def toString: String = "DAGSchedulerSuiteRDD " + id
+    }
+  }
+
+  /**
+   * EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task
+   * is from a particular RDD.
+   */
+  def taskSetForRdd(rdd: MyRDD): TaskSet = {
+    val matcher = taskSetMatchers.getOrElseUpdate(rdd,
+      new IArgumentMatcher {
+        override def matches(actual: Any): Boolean = {
+          val taskSet = actual.asInstanceOf[TaskSet]
+          taskSet.tasks(0) match {
+            case rt: ResultTask[_, _] => rt.rdd.id == rdd.id
+            case smt: ShuffleMapTask => smt.rdd.id == rdd.id
+            case _ => false
+          }
+        }
+        override def appendTo(buf: StringBuffer) {
+          buf.append("taskSetForRdd(" + rdd + ")")
+        }
+      })
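+    // Register the matcher with EasyMock; the null returned below is only a placeholder for
+    // the matched TaskSet argument.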
+    EasyMock.reportMatcher(matcher)
+    return null
+  }
+
+  /**
+   * Set up an EasyMock expectation so that blockManagerMaster.getLocations() responds with
+   * entries from cacheLocations.
+   */
+  def expectGetLocations(): Unit = {
+    EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])).
+        andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] {
+      override def answer(): Seq[Seq[BlockManagerId]] = {
+        val blocks = getCurrentArguments()(0).asInstanceOf[Array[String]]
+        return blocks.map { name =>
+          val pieces = name.split("_")
+          if (pieces(0) == "rdd") {
+            val key = pieces(1).toInt -> pieces(2).toInt
+            if (cacheLocations.contains(key)) {
+              cacheLocations(key)
+            } else {
+              Seq[BlockManagerId]()
+            }
+          } else {
+            Seq[BlockManagerId]()
+          }
+        }.toSeq
+      }
+    }).anyTimes()
+  }
+
+  /**
+   * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
+   * the scheduler not to exit.
+   *
+   * After processing the event, submit waiting stages as is done on most iterations of the
+   * DAGScheduler event loop.
+   */
+  def runEvent(event: DAGSchedulerEvent) {
+    assert(!scheduler.processEvent(event))
+    scheduler.submitWaitingStages()
+  }
+
+  /**
+   * Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be
+   * called from a resetExpecting { ... } block.
+   *
+   * Returns an EasyMock Capture that will contain the TaskSet after the stage is submitted.
+   * Most tests should use interceptStage() instead of this directly.
+   */
+  def expectStage(rdd: MyRDD): Capture[TaskSet] = {
+    val taskSetCapture = new Capture[TaskSet]
+    taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd)))
+    return taskSetCapture
+  }
+
+  /**
+   * Expect the supplied code snippet to submit a stage for the specified RDD.
+   * Returns the resulting TaskSet, with all of its tasks marked as belonging to the
+   * current MapOutputTracker generation.
+   */
+  def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = {
+    var capture: Capture[TaskSet] = null
+    resetExpecting {
+      capture = expectStage(rdd)
+    }
+    whenExecuting {
+      f
+    }
+    val taskSet = capture.getValue
+    for (task <- taskSet.tasks) {
+      task.generation = mapOutputTracker.getGeneration
+    }
+    return taskSet
+  }
+
+  /**
+   * Send the given CompletionEvent messages for the tasks in the TaskSet.
+   */
+  def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+    assert(taskSet.tasks.size >= results.size)
+    for ((result, i) <- results.zipWithIndex) {
+      if (i < taskSet.tasks.size) {
+        runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
+      }
+    }
+  }
+
+  /**
+   * Assert that the supplied TaskSet has exactly the given preferredLocations.
+   */
+  def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
+    assert(locations.size === taskSet.tasks.size)
+    for ((expectLocs, taskLocs) <-
+            taskSet.tasks.map(_.preferredLocations).zip(locations)) {
+      assert(expectLocs === taskLocs)
+    }
+  }
+
+  /**
+   * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+   * below, we do not expect this function to ever be executed; instead, we will return results
+   * directly through CompletionEvents.
+   */
+  def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int =
+     it.next._1.asInstanceOf[Int]
+
+
+  /**
+   * Start a job to compute the given RDD. Returns the JobWaiter and result array that
+   * will collect the job's results via callbacks from the DAGScheduler.
+   */
+  def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): (JobWaiter[Int], Array[Int]) = {
+    val resultArray = new Array[Int](rdd.splits.size)
+    val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int](
+        rdd,
+        jobComputeFunc,
+        (0 to (rdd.splits.size - 1)),
+        "test-site",
+        allowLocal,
+        (i: Int, value: Int) => resultArray(i) = value
+    )
+    lastJobWaiter = waiter
+    lastJobResult = resultArray
+    runEvent(toSubmit)
+    return (waiter, resultArray)
+  }
+
+  /**
+   * Assert that a job we started has failed.
+   */
+  def expectJobException(waiter: JobWaiter[Int] = lastJobWaiter) {
+    waiter.awaitResult() match {
+      case JobSucceeded => fail()
+      case JobFailed(_) => return
+    }
+  }
+
+  /**
+   * Assert that a job we started has succeeded and has the given result.
+   */
+  def expectJobResult(expected: Array[Int], waiter: JobWaiter[Int] = lastJobWaiter,
+                      result: Array[Int] = lastJobResult) {
+    waiter.awaitResult match {
+      case JobSucceeded =>
+        assert(expected === result)
+      case JobFailed(_) =>
+        fail()
+    }
+  }
+
+  def makeMapStatus(host: String, reduces: Int): MapStatus =
+    new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
+
+  test("zero split job") {
+    val rdd = makeRdd(0, Nil)
+    var numResults = 0
+    def accumulateResult(partition: Int, value: Int) {
+      numResults += 1
+    }
+    scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false, accumulateResult)
+    assert(numResults === 0)
+  }
+
+  test("run trivial job") {
+    val rdd = makeRdd(1, Nil)
+    val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+    respondToTaskSet(taskSet, List( (Success, 42) ))
+    expectJobResult(Array(42))
+  }
+
+  test("local job") {
+    val rdd = new MyRDD(sc, Nil) {
+      override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
+        Array(42 -> 0).iterator
+      override def getSplits() = Array( new Split { override def index = 0 } )
+      override def getPreferredLocations(split: Split) = Nil
+      override def toString = "DAGSchedulerSuite Local RDD"
+    }
+    submitRdd(rdd, true)
+    expectJobResult(Array(42))
+  }
+
+  test("run trivial job w/ dependency") {
+    val baseRdd = makeRdd(1, Nil)
+    val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+    val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+    respondToTaskSet(taskSet, List( (Success, 42) ))
+    expectJobResult(Array(42))
+  }
+
+  test("cache location preferences w/ dependency") {
+    val baseRdd = makeRdd(1, Nil)
+    val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+    cacheLocations(baseRdd.id -> 0) =
+      Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+    val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+    expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB")))
+    respondToTaskSet(taskSet, List( (Success, 42) ))
+    expectJobResult(Array(42))
+  }
+
+  test("trivial job failure") {
+    val rdd = makeRdd(1, Nil)
+    val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+    runEvent(TaskSetFailed(taskSet, "test failure"))
+    expectJobException()
+  }
+
+  test("run trivial shuffle") {
+    val shuffleMapRdd = makeRdd(2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = makeRdd(1, List(shuffleDep))
+
+    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+    val secondStage = interceptStage(reduceRdd) {
+      respondToTaskSet(firstStage, List(
+        (Success, makeMapStatus("hostA", 1)),
+        (Success, makeMapStatus("hostB", 1))
+      ))
+    }
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+           Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+    respondToTaskSet(secondStage, List( (Success, 42) ))
+    expectJobResult(Array(42))
+  }
+
+  test("run trivial shuffle with fetch failure") {
+    val shuffleMapRdd = makeRdd(2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = makeRdd(2, List(shuffleDep))
+
+    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+    val secondStage = interceptStage(reduceRdd) {
+      respondToTaskSet(firstStage, List(
+        (Success, makeMapStatus("hostA", 1)),
+        (Success, makeMapStatus("hostB", 1))
+      ))
+    }
+    resetExpecting {
+      blockManagerMaster.removeExecutor("exec-hostA")
+    }
+    whenExecuting {
+      respondToTaskSet(secondStage, List(
+        (Success, 42),
+        (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)
+      ))
+    }
+    val thirdStage = interceptStage(shuffleMapRdd) {
+      scheduler.resubmitFailedStages()
+    }
+    val fourthStage = interceptStage(reduceRdd) {
+      respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) ))
+    }
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+                   Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+    respondToTaskSet(fourthStage, List( (Success, 43) ))
+    expectJobResult(Array(42, 43))
+  }
+
+  test("ignore late map task completions") {
+    val shuffleMapRdd = makeRdd(2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = makeRdd(2, List(shuffleDep))
+
+    val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+    val oldGeneration = mapOutputTracker.getGeneration
+    resetExpecting {
+      blockManagerMaster.removeExecutor("exec-hostA")
+    }
+    whenExecuting {
+      runEvent(ExecutorLost("exec-hostA"))
+    }
+    val newGeneration = mapOutputTracker.getGeneration
+    assert(newGeneration > oldGeneration)
+    val noAccum = Map[Long, Any]()
+    // We rely on the event queue being ordered and on the generation number increasing by 1.
+    // This completion should be ignored because its generation is too old.
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+    // should work because it's a non-failed host
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum))
+    // should be ignored for being too old
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+    taskSet.tasks(1).generation = newGeneration
+    val secondStage = interceptStage(reduceRdd) {
+      runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
+    }
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+           Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+    respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) ))
+    expectJobResult(Array(42, 43))
+  }
+
+  test("run trivial shuffle with out-of-band failure and retry") {
+    val shuffleMapRdd = makeRdd(2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = makeRdd(1, List(shuffleDep))
+
+    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+    resetExpecting {
+      blockManagerMaster.removeExecutor("exec-hostA")
+    }
+    whenExecuting {
+      runEvent(ExecutorLost("exec-hostA"))
+    }
+    // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
+    // rather than marking it as failed and waiting.
+    val secondStage = interceptStage(shuffleMapRdd) {
+      respondToTaskSet(firstStage, List(
+        (Success, makeMapStatus("hostA", 1)),
+        (Success, makeMapStatus("hostB", 1))
+      ))
+    }
+    val thirdStage = interceptStage(reduceRdd) {
+      respondToTaskSet(secondStage, List(
+        (Success, makeMapStatus("hostC", 1))
+      ))
+    }
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+           Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+    respondToTaskSet(thirdStage, List( (Success, 42) ))
+    expectJobResult(Array(42))
+  }
+
+  test("recursive shuffle failures") {
+    val shuffleOneRdd = makeRdd(2, Nil)
+    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+    val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+    val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+    val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+    val secondStage = interceptStage(shuffleTwoRdd) {
+      respondToTaskSet(firstStage, List(
+        (Success, makeMapStatus("hostA", 2)),
+        (Success, makeMapStatus("hostB", 2))
+      ))
+    }
+    val thirdStage = interceptStage(finalRdd) {
+      respondToTaskSet(secondStage, List(
+        (Success, makeMapStatus("hostA", 1)),
+        (Success, makeMapStatus("hostC", 1))
+      ))
+    }
+    resetExpecting {
+      blockManagerMaster.removeExecutor("exec-hostA")
+    }
+    whenExecuting {
+      respondToTaskSet(thirdStage, List(
+        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+      ))
+    }
+    val recomputeOne = interceptStage(shuffleOneRdd) {
+      scheduler.resubmitFailedStages()
+    }
+    val recomputeTwo = interceptStage(shuffleTwoRdd) {
+      respondToTaskSet(recomputeOne, List(
+        (Success, makeMapStatus("hostA", 2))
+      ))
+    }
+    val finalStage = interceptStage(finalRdd) {
+      respondToTaskSet(recomputeTwo, List(
+        (Success, makeMapStatus("hostA", 1))
+      ))
+    }
+    respondToTaskSet(finalStage, List( (Success, 42) ))
+    expectJobResult(Array(42))
+  }
+
+  test("cached post-shuffle") {
+    val shuffleOneRdd = makeRdd(2, Nil)
+    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+    val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+    val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+    val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+    cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+    cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+    val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+      respondToTaskSet(firstShuffleStage, List(
+        (Success, makeMapStatus("hostA", 2)),
+        (Success, makeMapStatus("hostB", 2))
+      ))
+    }
+    val reduceStage = interceptStage(finalRdd) {
+      respondToTaskSet(secondShuffleStage, List(
+        (Success, makeMapStatus("hostA", 1)),
+        (Success, makeMapStatus("hostB", 1))
+      ))
+    }
+    resetExpecting {
+      blockManagerMaster.removeExecutor("exec-hostA")
+    }
+    whenExecuting {
+      respondToTaskSet(reduceStage, List(
+        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+      ))
+    }
+    // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
+    val recomputeTwo = interceptStage(shuffleTwoRdd) {
+      scheduler.resubmitFailedStages()
+    }
+    expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD")))
+    val finalRetry = interceptStage(finalRdd) {
+      respondToTaskSet(recomputeTwo, List(
+        (Success, makeMapStatus("hostD", 1))
+      ))
+    }
+    respondToTaskSet(finalRetry, List( (Success, 42) ))
+    expectJobResult(Array(42))
+  }
+
+  test("cached post-shuffle but fails") {
+    val shuffleOneRdd = makeRdd(2, Nil)
+    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+    val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+    val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+    val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+    cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+    cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+    val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+      respondToTaskSet(firstShuffleStage, List(
+        (Success, makeMapStatus("hostA", 2)),
+        (Success, makeMapStatus("hostB", 2))
+      ))
+    }
+    val reduceStage = interceptStage(finalRdd) {
+      respondToTaskSet(secondShuffleStage, List(
+        (Success, makeMapStatus("hostA", 1)),
+        (Success, makeMapStatus("hostB", 1))
+      ))
+    }
+    resetExpecting {
+      blockManagerMaster.removeExecutor("exec-hostA")
+    }
+    whenExecuting {
+      respondToTaskSet(reduceStage, List(
+        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+      ))
+    }
+    val recomputeTwoCached = interceptStage(shuffleTwoRdd) {
+      scheduler.resubmitFailedStages()
+    }
+    expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD")))
+    intercept[FetchFailedException]{
+      mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0)
+    }
+
+    // Simulate the shuffle input data failing to be cached.
+    cacheLocations.remove(shuffleTwoRdd.id -> 0)
+    respondToTaskSet(recomputeTwoCached, List(
+      (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null)
+    ))
+
+    // After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit
+    // everything.
+    val recomputeOne = interceptStage(shuffleOneRdd) {
+      scheduler.resubmitFailedStages()
+    }
+    // We use hostA here to make sure DAGScheduler doesn't think it's still dead.
+    val recomputeTwoUncached = interceptStage(shuffleTwoRdd) {
+      respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) ))
+    }
+    expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]()))
+    val finalRetry = interceptStage(finalRdd) {
+      respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) ))
+    }
+    respondToTaskSet(finalRetry, List( (Success, 42) ))
+    expectJobResult(Array(42))
+  }
+}
diff --git a/pom.xml b/pom.xml
index c6b9012dc67b23abd4f8fc62f3cb58fa5ebf7d11..7e06cae052b58d6ebc1e5240f4810601287d9ca2 100644
--- a/pom.xml
+++ b/pom.xml
@@ -273,6 +273,12 @@
         <version>1.8</version>
         <scope>test</scope>
       </dependency>
+      <dependency>
+        <groupId>org.easymock</groupId>
+        <artifactId>easymock</artifactId>
+        <version>3.1</version>
+        <scope>test</scope>
+      </dependency>
       <dependency>
         <groupId>org.scalacheck</groupId>
         <artifactId>scalacheck_${scala.version}</artifactId>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 03b8094f7dce58b38b40d42564ba3a4467af64c9..af8b5ba01745b59f2c8c82f3c6c6bdb58a01fa45 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -92,7 +92,8 @@ object SparkBuild extends Build {
       "org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011",
       "org.scalatest" %% "scalatest" % "1.8" % "test",
       "org.scalacheck" %% "scalacheck" % "1.9" % "test",
-      "com.novocode" % "junit-interface" % "0.8" % "test"
+      "com.novocode" % "junit-interface" % "0.8" % "test",
+      "org.easymock" % "easymock" % "3.1" % "test"
     ),
     parallelExecution := false,
     /* Workaround for issue #206 (fixed after SBT 0.11.0) */