diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4e1250a14d7ca4f1552ada2a81051b6cafd630c0..d4e0d6db0cf36f25e0275ea991b78b01f8960a57 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -141,7 +141,13 @@ class DAGScheduler(
 
   private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]]
   private[scheduler] val stageIdToStage = new HashMap[Int, Stage]
-  private[scheduler] val shuffleToMapStage = new HashMap[Int, ShuffleMapStage]
+  /**
+   * Mapping from shuffle dependency ID to the ShuffleMapStage that will generate the data for
+   * that dependency. Only includes stages that are part of currently running job (when the job(s)
+   * that require the shuffle stage complete, the mapping will be removed, and the only record of
+   * the shuffle data will be in the MapOutputTracker).
+   */
+  private[scheduler] val shuffleIdToMapStage = new HashMap[Int, ShuffleMapStage]
   private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob]
 
   // Stages we need to run whose parents aren't done
@@ -276,86 +282,55 @@ class DAGScheduler(
   }
 
   /**
-   * Get or create a shuffle map stage for the given shuffle dependency's map side.
+   * Gets a shuffle map stage if one exists in shuffleIdToMapStage. Otherwise, if the
+   * shuffle map stage doesn't already exist, this method will create the shuffle map stage in
+   * addition to any missing ancestor shuffle map stages.
    */
-  private def getShuffleMapStage(
+  private def getOrCreateShuffleMapStage(
       shuffleDep: ShuffleDependency[_, _, _],
       firstJobId: Int): ShuffleMapStage = {
-    shuffleToMapStage.get(shuffleDep.shuffleId) match {
-      case Some(stage) => stage
+    shuffleIdToMapStage.get(shuffleDep.shuffleId) match {
+      case Some(stage) =>
+        stage
+
       case None =>
-        // We are going to register ancestor shuffle dependencies
-        getAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>
-          if (!shuffleToMapStage.contains(dep.shuffleId)) {
-            shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId)
+        // Create stages for all missing ancestor shuffle dependencies.
+        getMissingAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>
+          // Even though getMissingAncestorShuffleDependencies only returns shuffle dependencies
+          // that were not already in shuffleIdToMapStage, it's possible that by the time we
+          // get to a particular dependency in the foreach loop, it's been added to
+          // shuffleIdToMapStage by the stage creation process for an earlier dependency. See
+          // SPARK-13902 for more information.
+          if (!shuffleIdToMapStage.contains(dep.shuffleId)) {
+            createShuffleMapStage(dep, firstJobId)
           }
         }
-        // Then register current shuffleDep
-        val stage = newOrUsedShuffleStage(shuffleDep, firstJobId)
-        shuffleToMapStage(shuffleDep.shuffleId) = stage
-        stage
+        // Finally, create a stage for the given shuffle dependency.
+        createShuffleMapStage(shuffleDep, firstJobId)
     }
   }
 
   /**
-   * Helper function to eliminate some code re-use when creating new stages.
+   * Creates a ShuffleMapStage that generates the given shuffle dependency's partitions. If a
+   * previously run stage generated the same shuffle data, this function will copy the output
+   * locations that are still available from the previous shuffle to avoid unnecessarily
+   * regenerating data.
    */
-  private def getParentStagesAndId(rdd: RDD[_], firstJobId: Int): (List[Stage], Int) = {
-    val parentStages = getParentStages(rdd, firstJobId)
+  def createShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): ShuffleMapStage = {
+    val rdd = shuffleDep.rdd
+    val numTasks = rdd.partitions.length
+    val parents = getOrCreateParentStages(rdd, jobId)
     val id = nextStageId.getAndIncrement()
-    (parentStages, id)
-  }
-
-  /**
-   * Create a ShuffleMapStage as part of the (re)-creation of a shuffle map stage in
-   * newOrUsedShuffleStage.  The stage will be associated with the provided firstJobId.
-   * Production of shuffle map stages should always use newOrUsedShuffleStage, not
-   * newShuffleMapStage directly.
-   */
-  private def newShuffleMapStage(
-      rdd: RDD[_],
-      numTasks: Int,
-      shuffleDep: ShuffleDependency[_, _, _],
-      firstJobId: Int,
-      callSite: CallSite): ShuffleMapStage = {
-    val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, firstJobId)
-    val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages,
-      firstJobId, callSite, shuffleDep)
-
-    stageIdToStage(id) = stage
-    updateJobIdStageIdMaps(firstJobId, stage)
-    stage
-  }
+    val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep)
 
-  /**
-   * Create a ResultStage associated with the provided jobId.
-   */
-  private def newResultStage(
-      rdd: RDD[_],
-      func: (TaskContext, Iterator[_]) => _,
-      partitions: Array[Int],
-      jobId: Int,
-      callSite: CallSite): ResultStage = {
-    val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId)
-    val stage = new ResultStage(id, rdd, func, partitions, parentStages, jobId, callSite)
     stageIdToStage(id) = stage
+    shuffleIdToMapStage(shuffleDep.shuffleId) = stage
     updateJobIdStageIdMaps(jobId, stage)
-    stage
-  }
 
-  /**
-   * Create a shuffle map Stage for the given RDD.  The stage will also be associated with the
-   * provided firstJobId.  If a stage for the shuffleId existed previously so that the shuffleId is
-   * present in the MapOutputTracker, then the number and location of available outputs are
-   * recovered from the MapOutputTracker
-   */
-  private def newOrUsedShuffleStage(
-      shuffleDep: ShuffleDependency[_, _, _],
-      firstJobId: Int): ShuffleMapStage = {
-    val rdd = shuffleDep.rdd
-    val numTasks = rdd.partitions.length
-    val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite)
     if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
+      // A previously run stage generated partitions for this shuffle, so for each output
+      // that's still available, copy information about that output location to the new stage
+      // (so we don't unnecessarily re-compute that data).
       val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
       val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
       (0 until locs.length).foreach { i =>
@@ -373,18 +348,36 @@ class DAGScheduler(
     stage
   }
 
+  /**
+   * Create a ResultStage associated with the provided jobId.
+   */
+  private def createResultStage(
+      rdd: RDD[_],
+      func: (TaskContext, Iterator[_]) => _,
+      partitions: Array[Int],
+      jobId: Int,
+      callSite: CallSite): ResultStage = {
+    val parents = getOrCreateParentStages(rdd, jobId)
+    val id = nextStageId.getAndIncrement()
+    val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite)
+    stageIdToStage(id) = stage
+    updateJobIdStageIdMaps(jobId, stage)
+    stage
+  }
+
   /**
    * Get or create the list of parent stages for a given RDD.  The new Stages will be created with
    * the provided firstJobId.
    */
-  private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
+  private def getOrCreateParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
     getShuffleDependencies(rdd).map { shuffleDep =>
-      getShuffleMapStage(shuffleDep, firstJobId)
+      getOrCreateShuffleMapStage(shuffleDep, firstJobId)
     }.toList
   }
 
   /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */
-  private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
+  private def getMissingAncestorShuffleDependencies(
+      rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
     val ancestors = new Stack[ShuffleDependency[_, _, _]]
     val visited = new HashSet[RDD[_]]
     // We are manually maintaining a stack here to prevent StackOverflowError
@@ -396,7 +389,7 @@ class DAGScheduler(
       if (!visited(toVisit)) {
         visited += toVisit
         getShuffleDependencies(toVisit).foreach { shuffleDep =>
-          if (!shuffleToMapStage.contains(shuffleDep.shuffleId)) {
+          if (!shuffleIdToMapStage.contains(shuffleDep.shuffleId)) {
             ancestors.push(shuffleDep)
             waitingForVisit.push(shuffleDep.rdd)
           } // Otherwise, the dependency and its ancestors have already been registered.
@@ -453,7 +446,7 @@ class DAGScheduler(
           for (dep <- rdd.dependencies) {
             dep match {
               case shufDep: ShuffleDependency[_, _, _] =>
-                val mapStage = getShuffleMapStage(shufDep, stage.firstJobId)
+                val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId)
                 if (!mapStage.isAvailable) {
                   missing += mapStage
                 }
@@ -482,8 +475,7 @@ class DAGScheduler(
         val s = stages.head
         s.jobIds += jobId
         jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
-        val parents: List[Stage] = getParentStages(s.rdd, jobId)
-        val parentsWithoutThisJobId = parents.filter { ! _.jobIds.contains(jobId) }
+        val parentsWithoutThisJobId = s.parents.filter { ! _.jobIds.contains(jobId) }
         updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
       }
     }
@@ -516,8 +508,8 @@ class DAGScheduler(
                   logDebug("Removing running stage %d".format(stageId))
                   runningStages -= stage
                 }
-                for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
-                  shuffleToMapStage.remove(k)
+                for ((k, v) <- shuffleIdToMapStage.find(_._2 == stage)) {
+                  shuffleIdToMapStage.remove(k)
                 }
                 if (waitingStages.contains(stage)) {
                   logDebug("Removing stage %d from waiting set.".format(stageId))
@@ -843,7 +835,7 @@ class DAGScheduler(
     try {
       // New stage creation may throw an exception if, for example, jobs are run on a
       // HadoopRDD whose underlying HDFS files have been deleted.
-      finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite)
+      finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)
     } catch {
       case e: Exception =>
         logWarning("Creating new stage failed due to exception - job: " + jobId, e)
@@ -881,7 +873,7 @@ class DAGScheduler(
     try {
       // New stage creation may throw an exception if, for example, jobs are run on a
       // HadoopRDD whose underlying HDFS files have been deleted.
-      finalStage = getShuffleMapStage(dependency, jobId)
+      finalStage = getOrCreateShuffleMapStage(dependency, jobId)
     } catch {
       case e: Exception =>
         logWarning("Creating new stage failed due to exception - job: " + jobId, e)
@@ -966,7 +958,6 @@ class DAGScheduler(
         case s: ShuffleMapStage =>
           partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
         case s: ResultStage =>
-          val job = s.activeJob.get
           partitionsToCompute.map { id =>
             val p = s.partitions(id)
             (id, getPreferredLocs(stage.rdd, p))
@@ -1028,7 +1019,6 @@ class DAGScheduler(
           }
 
         case stage: ResultStage =>
-          val job = stage.activeJob.get
           partitionsToCompute.map { id =>
             val p: Int = stage.partitions(id)
             val part = stage.rdd.partitions(p)
@@ -1244,7 +1234,7 @@ class DAGScheduler(
 
       case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
         val failedStage = stageIdToStage(task.stageId)
-        val mapStage = shuffleToMapStage(shuffleId)
+        val mapStage = shuffleIdToMapStage(shuffleId)
 
         if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
           logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
@@ -1334,14 +1324,14 @@ class DAGScheduler(
 
       if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) {
         // TODO: This will be really slow if we keep accumulating shuffle map stages
-        for ((shuffleId, stage) <- shuffleToMapStage) {
+        for ((shuffleId, stage) <- shuffleIdToMapStage) {
           stage.removeOutputsOnExecutor(execId)
           mapOutputTracker.registerMapOutputs(
             shuffleId,
             stage.outputLocInMapOutputTrackerFormat(),
             changeEpoch = true)
         }
-        if (shuffleToMapStage.isEmpty) {
+        if (shuffleIdToMapStage.isEmpty) {
           mapOutputTracker.incrementEpoch()
         }
         clearCacheLocs()
@@ -1496,7 +1486,7 @@ class DAGScheduler(
         for (dep <- rdd.dependencies) {
           dep match {
             case shufDep: ShuffleDependency[_, _, _] =>
-              val mapStage = getShuffleMapStage(shufDep, stage.firstJobId)
+              val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId)
               if (!mapStage.isAvailable) {
                 waitingForVisit.push(mapStage.rdd)
               }  // Otherwise there's no need to follow the dependency back
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index ab8e95314fdf3dd2b0a3b477ca7c9f24798c8b60..63a494006c91158c4b3bc710b0220405840aab9f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -360,12 +360,12 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
 
     submit(rddD, Array(0))
 
-    assert(scheduler.shuffleToMapStage.size === 3)
+    assert(scheduler.shuffleIdToMapStage.size === 3)
     assert(scheduler.activeJobs.size === 1)
 
-    val mapStageA = scheduler.shuffleToMapStage(s_A)
-    val mapStageB = scheduler.shuffleToMapStage(s_B)
-    val mapStageC = scheduler.shuffleToMapStage(s_C)
+    val mapStageA = scheduler.shuffleIdToMapStage(s_A)
+    val mapStageB = scheduler.shuffleIdToMapStage(s_B)
+    val mapStageC = scheduler.shuffleIdToMapStage(s_C)
     val finalStage = scheduler.activeJobs.head.finalStage
 
     assert(mapStageA.parents.isEmpty)
@@ -2072,7 +2072,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     assert(scheduler.jobIdToStageIds.isEmpty)
     assert(scheduler.stageIdToStage.isEmpty)
     assert(scheduler.runningStages.isEmpty)
-    assert(scheduler.shuffleToMapStage.isEmpty)
+    assert(scheduler.shuffleIdToMapStage.isEmpty)
     assert(scheduler.waitingStages.isEmpty)
     assert(scheduler.outputCommitCoordinator.isEmpty)
   }