diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index ade372be092aeca01605e6cb8eb3bb5b25550a61..995862ece5944d57202ffc006a961f8c37f5ca0d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -353,10 +353,12 @@ class DAGScheduler(
     if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
       val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
       val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
-      for (i <- 0 until locs.length) {
-        stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing
+      (0 until locs.length).foreach { i =>
+        if (locs(i) ne null) {
+          // locs(i) will be null if missing
+          stage.addOutputLoc(i, locs(i))
+        }
       }
-      stage.numAvailableOutputs = locs.count(_ != null)
     } else {
       // Kind of ugly: need to register RDDs with the cache and map output tracker here
       // since we can't do it in the RDD constructor because # of partitions is unknown
@@ -894,7 +896,7 @@ class DAGScheduler(
     submitStage(finalStage)
 
     // If the whole stage has already finished, tell the listener and remove it
-    if (!finalStage.outputLocs.contains(Nil)) {
+    if (finalStage.isAvailable) {
       markMapStageJobAsFinished(job, mapOutputTracker.getStatistics(dependency))
     }
 
@@ -931,24 +933,12 @@ class DAGScheduler(
     stage.pendingPartitions.clear()
 
     // First figure out the indexes of partition ids to compute.
-    val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = {
-      stage match {
-        case stage: ShuffleMapStage =>
-          val allPartitions = 0 until stage.numPartitions
-          val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty }
-          (allPartitions, filteredPartitions)
-        case stage: ResultStage =>
-          val job = stage.resultOfJob.get
-          val allPartitions = 0 until job.numPartitions
-          val filteredPartitions = allPartitions.filter { id => !job.finished(id) }
-          (allPartitions, filteredPartitions)
-      }
-    }
+    val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
 
     // Create internal accumulators if the stage has no accumulators initialized.
     // Reset internal accumulators only if this stage is not partially submitted
     // Otherwise, we may override existing accumulator values from some tasks
-    if (stage.internalAccumulators.isEmpty || allPartitions == partitionsToCompute) {
+    if (stage.internalAccumulators.isEmpty || stage.numPartitions == partitionsToCompute.size) {
       stage.resetInternalAccumulators()
     }
@@ -1202,7 +1192,7 @@ class DAGScheduler(
 
               clearCacheLocs()
 
-              if (shuffleStage.outputLocs.contains(Nil)) {
+              if (!shuffleStage.isAvailable) {
                 // Some tasks had failed; let's resubmit this shuffleStage
                 // TODO: Lower-level scheduler should also deal with this
                 logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
index c0451da1f0247b11e96c7a2c5fd29c8463d81de6..c1d86af7e8fb53b9f1e3ea75a6e0d606af0cb50a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
@@ -43,5 +43,10 @@ private[spark] class ResultStage(
    */
   var resultOfJob: Option[ActiveJob] = None
 
+  override def findMissingPartitions(): Seq[Int] = {
+    val job = resultOfJob.get
+    (0 until job.numPartitions).filter(id => !job.finished(id))
+  }
+
   override def toString: String = "ResultStage " + id
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
index 7d929608764036dd351836a566d6cd16c96170f0..3832d99eddaef1ab8f30199c679ea18b7ddd6698 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
@@ -48,12 +48,33 @@ private[spark] class ShuffleMapStage(
   /** Running map-stage jobs that were submitted to execute this stage independently (if any) */
   var mapStageJobs: List[ActiveJob] = Nil
 
+  /**
+   * Number of partitions that have shuffle outputs.
+   * When this reaches [[numPartitions]], this map stage is ready.
+   * This should be kept consistent with `outputLocs.filter(!_.isEmpty).size`.
+   */
   var numAvailableOutputs: Int = 0
 
+  /**
+   * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs.
+   * This should be the same as `!outputLocs.contains(Nil)`.
+   */
   def isAvailable: Boolean = numAvailableOutputs == numPartitions
 
+  /**
+   * List of [[MapStatus]] for each partition. The index of the array is the map partition id,
+   * and each value in the array is the list of possible [[MapStatus]] for a partition
+   * (a single task might run multiple times).
+   */
   val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
 
+  override def findMissingPartitions(): Seq[Int] = {
+    val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
+    assert(missing.size == numPartitions - numAvailableOutputs,
+      s"${missing.size} missing, expected ${numPartitions - numAvailableOutputs}")
+    missing
+  }
+
   def addOutputLoc(partition: Int, status: MapStatus): Unit = {
     val prevList = outputLocs(partition)
     outputLocs(partition) = status :: prevList
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index a3829c319c48dc7d5bb6367d83745e43ceed24de..5ce4a484344f1c322119b557b74afea078f6a9f1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -61,7 +61,7 @@ private[scheduler] abstract class Stage(
     val callSite: CallSite)
   extends Logging {
 
-  val numPartitions = rdd.partitions.size
+  val numPartitions = rdd.partitions.length
 
   /** Set of jobs that this stage belongs to. */
   val jobIds = new HashSet[Int]
@@ -138,6 +138,9 @@ private[scheduler] abstract class Stage(
     case stage: Stage => stage != null && stage.id == id
     case _ => false
   }
+
+  /** Returns the sequence of partition ids that are missing (i.e. need to be computed). */
+  def findMissingPartitions(): Seq[Int]
 }
 
 private[scheduler] object Stage {
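
Note: the core of this patch is replacing a `stage match { ... }` in DAGScheduler with a polymorphic `findMissingPartitions()` call. The sketch below is a minimal, self-contained model of that pattern, not the real Spark classes; `Stage`, `ShuffleMapStage`, `ResultStage`, `MapStatus`, and `Demo` here are simplified stand-ins.

// Simplified stand-in for org.apache.spark.scheduler.MapStatus.
case class MapStatus(location: String)

// The abstract hook: each concrete stage reports its own missing
// partitions, so the scheduler no longer pattern-matches on stage type.
abstract class Stage(val numPartitions: Int) {
  /** Returns the sequence of partition ids that are missing (i.e. need to be computed). */
  def findMissingPartitions(): Seq[Int]
}

class ShuffleMapStage(numPartitions: Int) extends Stage(numPartitions) {
  // Invariant: numAvailableOutputs == outputLocs.count(_.nonEmpty)
  val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
  var numAvailableOutputs: Int = 0

  def isAvailable: Boolean = numAvailableOutputs == numPartitions

  def addOutputLoc(partition: Int, status: MapStatus): Unit = {
    val prevList = outputLocs(partition)
    outputLocs(partition) = status :: prevList
    if (prevList.isEmpty) {
      numAvailableOutputs += 1
    }
  }

  override def findMissingPartitions(): Seq[Int] =
    (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
}

class ResultStage(numPartitions: Int) extends Stage(numPartitions) {
  val finished = Array.fill(numPartitions)(false)

  override def findMissingPartitions(): Seq[Int] =
    (0 until numPartitions).filter(id => !finished(id))
}

object Demo {
  def main(args: Array[String]): Unit = {
    val stage = new ShuffleMapStage(3)
    stage.addOutputLoc(1, MapStatus("host-a")) // partition 1 has an output
    // Uniform scheduler-side call, regardless of the concrete stage type:
    val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
    println(partitionsToCompute) // Vector(0, 2)
    println(stage.isAvailable)   // false
  }
}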