diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index ade372be092aeca01605e6cb8eb3bb5b25550a61..995862ece5944d57202ffc006a961f8c37f5ca0d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -353,10 +353,12 @@ class DAGScheduler(
     if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
       val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
       val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
-      for (i <- 0 until locs.length) {
-        stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing
+      locs.indices.foreach { i =>
+        // locs(i) will be null if the map output for partition i is missing
+        if (locs(i) ne null) {
+          stage.addOutputLoc(i, locs(i))
+        }
       }
-      stage.numAvailableOutputs = locs.count(_ != null)
     } else {
       // Kind of ugly: need to register RDDs with the cache and map output tracker here
       // since we can't do it in the RDD constructor because # of partitions is unknown
@@ -894,7 +896,7 @@ class DAGScheduler(
     submitStage(finalStage)
 
     // If the whole stage has already finished, tell the listener and remove it
-    if (!finalStage.outputLocs.contains(Nil)) {
+    if (finalStage.isAvailable) {
       markMapStageJobAsFinished(job, mapOutputTracker.getStatistics(dependency))
     }
 
@@ -931,24 +933,12 @@ class DAGScheduler(
     stage.pendingPartitions.clear()
 
     // First figure out the indexes of partition ids to compute.
-    val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = {
-      stage match {
-        case stage: ShuffleMapStage =>
-          val allPartitions = 0 until stage.numPartitions
-          val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty }
-          (allPartitions, filteredPartitions)
-        case stage: ResultStage =>
-          val job = stage.resultOfJob.get
-          val allPartitions = 0 until job.numPartitions
-          val filteredPartitions = allPartitions.filter { id => !job.finished(id) }
-          (allPartitions, filteredPartitions)
-      }
-    }
+    val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
 
     // Create internal accumulators if the stage has no accumulators initialized.
     // Reset internal accumulators only if this stage is not partially submitted
     // Otherwise, we may override existing accumulator values from some tasks
-    if (stage.internalAccumulators.isEmpty || allPartitions == partitionsToCompute) {
+    if (stage.internalAccumulators.isEmpty || stage.numPartitions == partitionsToCompute.size) {
       stage.resetInternalAccumulators()
     }
 
@@ -1202,7 +1192,7 @@ class DAGScheduler(
 
               clearCacheLocs()
 
-              if (shuffleStage.outputLocs.contains(Nil)) {
+              if (!shuffleStage.isAvailable) {
                 // Some tasks had failed; let's resubmit this shuffleStage
                 // TODO: Lower-level scheduler should also deal with this
                 logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
index c0451da1f0247b11e96c7a2c5fd29c8463d81de6..c1d86af7e8fb53b9f1e3ea75a6e0d606af0cb50a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
@@ -43,5 +43,10 @@ private[spark] class ResultStage(
    */
   var resultOfJob: Option[ActiveJob] = None
 
+  override def findMissingPartitions(): Seq[Int] = {
+    val job = resultOfJob.get
+    (0 until job.numPartitions).filter(id => !job.finished(id))
+  }
+
   override def toString: String = "ResultStage " + id
 }
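
Note that `resultOfJob.get` assumes the stage is tied to an active job; if
`resultOfJob` is None, `get` throws a NoSuchElementException, matching the
scheduler's invariant that only submitted result stages are asked for missing
partitions. A quick illustration of the filter idiom itself, with hypothetical
finished flags rather than a real ActiveJob:

    // Partitions 0 and 3 are done; 1 and 2 still need to be computed.
    val numPartitions = 4
    val finished = Array(true, false, false, true)
    val missing: Seq[Int] = (0 until numPartitions).filter(id => !finished(id))
    assert(missing == Seq(1, 2))
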
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
index 7d929608764036dd351836a566d6cd16c96170f0..3832d99eddaef1ab8f30199c679ea18b7ddd6698 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
@@ -48,12 +48,33 @@ private[spark] class ShuffleMapStage(
   /** Running map-stage jobs that were submitted to execute this stage independently (if any) */
   var mapStageJobs: List[ActiveJob] = Nil
 
+  /**
+   * Number of partitions that have shuffle outputs.
+   * When this reaches [[numPartitions]], this map stage is ready.
+   * This should be kept consistent with `outputLocs.count(_.nonEmpty)`.
+   */
   var numAvailableOutputs: Int = 0
 
+  /**
+   * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs.
+   * This should be the same as `!outputLocs.contains(Nil)`.
+   */
   def isAvailable: Boolean = numAvailableOutputs == numPartitions
 
+  /**
+   * List of [[MapStatus]] for each partition. The index of the array is the map partition id,
+   * and each value in the array is the list of [[MapStatus]] recorded for that partition
+   * (a single task might run multiple times, so there can be more than one entry).
+   */
   val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
 
+  override def findMissingPartitions(): Seq[Int] = {
+    val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
+    assert(missing.size == numPartitions - numAvailableOutputs,
+      s"${missing.size} missing, expected ${numPartitions - numAvailableOutputs}")
+    missing
+  }
+
   def addOutputLoc(partition: Int, status: MapStatus): Unit = {
     val prevList = outputLocs(partition)
     outputLocs(partition) = status :: prevList
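
The new comments pin down an invariant: `numAvailableOutputs` must always equal
the number of non-empty entries in `outputLocs`, and `findMissingPartitions()`
now asserts exactly that. A minimal sketch of how the add path keeps the count
in sync (a simplified stand-in with String in place of MapStatus, not the real
class):

    class OutputTrackerSketch(numPartitions: Int) {
      val outputLocs = Array.fill[List[String]](numPartitions)(Nil)
      var numAvailableOutputs = 0

      def addOutputLoc(partition: Int, status: String): Unit = {
        val prevList = outputLocs(partition)
        outputLocs(partition) = status :: prevList
        // Count a partition only the first time it gains an output; a rerun
        // of the same task just prepends another status to the list.
        if (prevList == Nil) {
          numAvailableOutputs += 1
        }
      }

      def isAvailable: Boolean = numAvailableOutputs == numPartitions

      def findMissingPartitions(): Seq[Int] = {
        val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
        assert(missing.size == numPartitions - numAvailableOutputs)
        missing
      }
    }
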
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index a3829c319c48dc7d5bb6367d83745e43ceed24de..5ce4a484344f1c322119b557b74afea078f6a9f1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -61,7 +61,7 @@ private[scheduler] abstract class Stage(
     val callSite: CallSite)
   extends Logging {
 
-  val numPartitions = rdd.partitions.size
+  val numPartitions = rdd.partitions.length
 
   /** Set of jobs that this stage belongs to. */
   val jobIds = new HashSet[Int]
@@ -138,6 +138,9 @@ private[scheduler] abstract class Stage(
     case stage: Stage => stage != null && stage.id == id
     case _ => false
   }
+
+  /** Returns the sequence of partition ids that are missing (i.e. need to be computed). */
+  def findMissingPartitions(): Seq[Int]
 }
 
 private[scheduler] object Stage {
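
Declaring `findMissingPartitions()` on the abstract Stage lets callers treat
shuffle-map and result stages uniformly. (The `rdd.partitions.length` change is
a micro-cleanup: `partitions` is a Java array, so `.length` reads the field
directly while `.size` goes through Scala's implicit ArrayOps wrapper.) A
hedged usage sketch of the uniform call site, with a hypothetical helper rather
than DAGScheduler code:

    // Any stage type can now be asked the same question.
    trait HasMissingPartitions {
      def findMissingPartitions(): Seq[Int]
    }

    def needsResubmission(stage: HasMissingPartitions): Boolean =
      stage.findMissingPartitions().nonEmpty
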