diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index b8f2b96d7088dc4fd194580f5b885f50c5d87b55..e0fdd4597385895609440c597eb5bf99d97df70a 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -323,35 +323,60 @@ private[spark] object AccumulatorSuite {
  * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs.
  */
 private class SaveInfoListener extends SparkListener {
-  private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo]
-  private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo]
-  private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID
+  type StageId = Int
+  type StageAttemptId = Int
 
-  // Accesses must be synchronized to ensure failures in `jobCompletionCallback` are propagated
+  private val completedStageInfos = new ArrayBuffer[StageInfo]
+  private val completedTaskInfos =
+    new mutable.HashMap[(StageId, StageAttemptId), ArrayBuffer[TaskInfo]]
+
+  // Callback to call when a job completes. Parameter is job ID.
   @GuardedBy("this")
+  private var jobCompletionCallback: () => Unit = null
+  private var calledJobCompletionCallback: Boolean = false
   private var exception: Throwable = null
 
   def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq
-  def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq
+  def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.values.flatten.toSeq
+  def getCompletedTaskInfos(stageId: StageId, stageAttemptId: StageAttemptId): Seq[TaskInfo] =
+    completedTaskInfos.get((stageId, stageAttemptId)).getOrElse(Seq.empty[TaskInfo])
 
-  /** Register a callback to be called on job end. */
-  def registerJobCompletionCallback(callback: (Int => Unit)): Unit = {
-    jobCompletionCallback = callback
+  /**
+   * If `jobCompletionCallback` is set, block until the next call has finished.
+   * If the callback failed with an exception, throw it.
+   */
+  def awaitNextJobCompletion(): Unit = synchronized {
+    if (jobCompletionCallback != null) {
+      while (!calledJobCompletionCallback) {
+        wait()
+      }
+      calledJobCompletionCallback = false
+      if (exception != null) {
+        exception = null
+        throw exception
+      }
+    }
   }
 
-  /** Throw a stored exception, if any. */
-  def maybeThrowException(): Unit = synchronized {
-    if (exception != null) { throw exception }
+  /**
+   * Register a callback to be called on job end.
+   * A call to this should be followed by [[awaitNextJobCompletion]].
+   */
+  def registerJobCompletionCallback(callback: () => Unit): Unit = {
+    jobCompletionCallback = callback
   }
 
   override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized {
     if (jobCompletionCallback != null) {
       try {
-        jobCompletionCallback(jobEnd.jobId)
+        jobCompletionCallback()
       } catch {
         // Store any exception thrown here so we can throw them later in the main thread.
         // Otherwise, if `jobCompletionCallback` threw something it wouldn't fail the test.
         case NonFatal(e) => exception = e
+      } finally {
+        calledJobCompletionCallback = true
+        notify()
       }
     }
   }
@@ -361,7 +386,8 @@ private class SaveInfoListener extends SparkListener {
   }
 
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-    completedTaskInfos += taskEnd.taskInfo
+    completedTaskInfos.getOrElseUpdate(
+      (taskEnd.stageId, taskEnd.stageAttemptId), new ArrayBuffer[TaskInfo]) += taskEnd.taskInfo
   }
 }
 
diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
index 630b46f828df718fc3ae0172e5cb0cdb7390e0f5..44a16e26f49351c96df32b0f62eb2c61ae33293b 100644
--- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockId, BlockStatus}
 
 
@@ -160,7 +161,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
       iter
     }
     // Register asserts in job completion callback to avoid flakiness
-    listener.registerJobCompletionCallback { _ =>
+    listener.registerJobCompletionCallback { () =>
       val stageInfos = listener.getCompletedStageInfos
       val taskInfos = listener.getCompletedTaskInfos
       assert(stageInfos.size === 1)
@@ -179,6 +180,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
       assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
     }
     rdd.count()
+    listener.awaitNextJobCompletion()
   }
 
   test("internal accumulators in multiple stages") {
@@ -205,7 +207,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
         iter
       }
     // Register asserts in job completion callback to avoid flakiness
-    listener.registerJobCompletionCallback { _ =>
+    listener.registerJobCompletionCallback { () =>
     // We ran 3 stages, and the accumulator values should be distinct
       val stageInfos = listener.getCompletedStageInfos
       assert(stageInfos.size === 3)
@@ -220,13 +222,66 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
     rdd.count()
   }
 
-  // TODO: these two tests are incorrect; they don't actually trigger stage retries.
-  ignore("internal accumulators in fully resubmitted stages") {
-    testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks
-  }
+  test("internal accumulators in resubmitted stages") {
+    val listener = new SaveInfoListener
+    val numPartitions = 10
+    sc = new SparkContext("local", "test")
+    sc.addSparkListener(listener)
+
+    // Simulate fetch failures in order to trigger a stage retry. Here we run 1 job with
+    // 2 stages. On the second stage, we trigger a fetch failure on the first stage attempt.
+    // This should retry both stages in the scheduler. Note that we only want to fail the
+    // first stage attempt because we want the stage to eventually succeed.
+    val x = sc.parallelize(1 to 100, numPartitions)
+      .mapPartitions { iter => TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1; iter }
+      .groupBy(identity)
+    val sid = x.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle.shuffleId
+    val rdd = x.mapPartitionsWithIndex { case (i, iter) =>
+      // Fail the first stage attempt. Here we use the task attempt ID to determine this.
+      // This job runs 2 stages, and we're in the second stage. Therefore, any task attempt
+      // ID that's < 2 * numPartitions belongs to the first attempt of this stage.
+      val taskContext = TaskContext.get()
+      val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2
+      if (isFirstStageAttempt) {
+        throw new FetchFailedException(
+          SparkEnv.get.blockManager.blockManagerId,
+          sid,
+          taskContext.partitionId(),
+          taskContext.partitionId(),
+          "simulated fetch failure")
+      } else {
+        iter
+      }
+    }
 
-  ignore("internal accumulators in partially resubmitted stages") {
-    testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset
+    // Register asserts in job completion callback to avoid flakiness
+    listener.registerJobCompletionCallback { () =>
+      val stageInfos = listener.getCompletedStageInfos
+      assert(stageInfos.size === 4) // 1 shuffle map stage + 1 result stage, both are retried
+      val mapStageId = stageInfos.head.stageId
+      val mapStageInfo1stAttempt = stageInfos.head
+      val mapStageInfo2ndAttempt = {
+        stageInfos.tail.find(_.stageId == mapStageId).getOrElse {
+          fail("expected two attempts of the same shuffle map stage.")
+        }
+      }
+      val stageAccum1stAttempt = findTestAccum(mapStageInfo1stAttempt.accumulables.values)
+      val stageAccum2ndAttempt = findTestAccum(mapStageInfo2ndAttempt.accumulables.values)
+      // Both map stages should have succeeded, since the fetch failure happened in the
+      // result stage, not the map stage. This means we should get the accumulator updates
+      // from all partitions.
+      assert(stageAccum1stAttempt.value.get.asInstanceOf[Long] === numPartitions)
+      assert(stageAccum2ndAttempt.value.get.asInstanceOf[Long] === numPartitions)
+      // Because this test resubmitted the map stage with all missing partitions, we should have
+      // created a fresh set of internal accumulators in the 2nd stage attempt. Assert this is
+      // the case by comparing the accumulator IDs between the two attempts.
+      // Note: it would be good to also test the case where the map stage is resubmitted where
+      // only a subset of the original partitions are missing. However, this scenario is very
+      // difficult to construct without potentially introducing flakiness.
+      assert(stageAccum1stAttempt.id != stageAccum2ndAttempt.id)
+    }
+    rdd.count()
+    listener.awaitNextJobCompletion()
   }
 
   test("internal accumulators are registered for cleanups") {
@@ -257,63 +312,6 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
-  /**
-   * Test whether internal accumulators are merged properly if some tasks fail.
-   * TODO: make this actually retry the stage.
-   */
-  private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = {
-    val listener = new SaveInfoListener
-    val numPartitions = 10
-    val numFailedPartitions = (0 until numPartitions).count(failCondition)
-    // This says use 1 core and retry tasks up to 2 times
-    sc = new SparkContext("local[1, 2]", "test")
-    sc.addSparkListener(listener)
-    val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) =>
-      val taskContext = TaskContext.get()
-      taskContext.taskMetrics().getAccum(TEST_ACCUM) += 1
-      // Fail the first attempts of a subset of the tasks
-      if (failCondition(i) && taskContext.attemptNumber() == 0) {
-        throw new Exception("Failing a task intentionally.")
-      }
-      iter
-    }
-    // Register asserts in job completion callback to avoid flakiness
-    listener.registerJobCompletionCallback { _ =>
-      val stageInfos = listener.getCompletedStageInfos
-      val taskInfos = listener.getCompletedTaskInfos
-      assert(stageInfos.size === 1)
-      assert(taskInfos.size === numPartitions + numFailedPartitions)
-      val stageAccum = findTestAccum(stageInfos.head.accumulables.values)
-      // If all partitions failed, then we would resubmit the whole stage again and create a
-      // fresh set of internal accumulators. Otherwise, these internal accumulators do count
-      // failed values, so we must include the failed values.
-      val expectedAccumValue =
-        if (numPartitions == numFailedPartitions) {
-          numPartitions
-        } else {
-          numPartitions + numFailedPartitions
-        }
-      assert(stageAccum.value.get.asInstanceOf[Long] === expectedAccumValue)
-      val taskAccumValues = taskInfos.flatMap { taskInfo =>
-        if (!taskInfo.failed) {
-          // If a task succeeded, its update value should always be 1
-          val taskAccum = findTestAccum(taskInfo.accumulables)
-          assert(taskAccum.update.isDefined)
-          assert(taskAccum.update.get.asInstanceOf[Long] === 1L)
-          assert(taskAccum.value.isDefined)
-          Some(taskAccum.value.get.asInstanceOf[Long])
-        } else {
-          // If a task failed, we should not get its accumulator values
-          assert(taskInfo.accumulables.isEmpty)
-          None
-        }
-      }
-      assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
-    }
-    rdd.count()
-    listener.maybeThrowException()
-  }
-
   /**
    * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup.
    */