Commit 0fbecc73 authored by Kay Ousterhout

[SPARK-19537] Move pendingPartitions to ShuffleMapStage.

The pendingPartitions instance variable should be moved to ShuffleMapStage,
because it is only used by ShuffleMapStages. This change is purely a
refactoring and does not change functionality.

I fixed this in an attempt to clarify some of the discussion around #16620, which I was having trouble reasoning about. I stole the helpful comment Imran wrote for pendingPartitions and used it here.

cc squito markhamstra jinxing64

Author: Kay Ousterhout <kayousterhout@gmail.com>

Closes #16876 from kayousterhout/SPARK-19537.
parent 226d3884
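In essence, the change moves a field down the class hierarchy so that it exists only where it is used. A minimal sketch of the before/after shape (standalone and simplified: access modifiers, constructor parameters, and all unrelated members of the real classes are elided):

    import scala.collection.mutable.HashSet

    abstract class Stage {
      /** Set of jobs that this stage belongs to -- state every stage type needs. */
      val jobIds = new HashSet[Int]
      // Before this commit, `val pendingPartitions = new HashSet[Int]` lived here,
      // even though only shuffle map stages ever used it.
    }

    class ShuffleMapStage extends Stage {
      /** After this commit: partitions still awaiting a usable map output. */
      val pendingPartitions = new HashSet[Int]
    }

    class ResultStage extends Stage // no pendingPartitions needed here

Call sites in the DAGScheduler that previously touched pendingPartitions on a plain Stage now either already hold a ShuffleMapStage (via a cast) or pattern-match on the stage type first, as the hunks below show.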
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

@@ -932,8 +932,6 @@ class DAGScheduler(
   /** Called when stage's parents are available and we can now do its task. */
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")
-    // Get our pending tasks and remember them in our pendingTasks entry
-    stage.pendingPartitions.clear()

     // First figure out the indexes of partition ids to compute.
     val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
@@ -1013,9 +1011,11 @@ class DAGScheduler(
     val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
     stage match {
       case stage: ShuffleMapStage =>
+        stage.pendingPartitions.clear()
         partitionsToCompute.map { id =>
           val locs = taskIdToLocations(id)
           val part = stage.rdd.partitions(id)
+          stage.pendingPartitions += id
           new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
             taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
             Option(sc.applicationId), sc.applicationAttemptId)
@@ -1039,9 +1039,8 @@ class DAGScheduler(
     }

     if (tasks.size > 0) {
-      logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
-      stage.pendingPartitions ++= tasks.map(_.partitionId)
-      logDebug("New pending partitions: " + stage.pendingPartitions)
+      logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " +
+        s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})")
       taskScheduler.submitTasks(new TaskSet(
         tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))
       stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
@@ -1147,7 +1146,6 @@ class DAGScheduler(
     val stage = stageIdToStage(task.stageId)
     event.reason match {
       case Success =>
-        stage.pendingPartitions -= task.partitionId
         task match {
           case rt: ResultTask[_, _] =>
             // Cast to ResultStage here because it's part of the ResultTask
@@ -1183,6 +1181,7 @@ class DAGScheduler(
           case smt: ShuffleMapTask =>
             val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
+            shuffleStage.pendingPartitions -= task.partitionId
             updateAccumulators(event)
             val status = event.result.asInstanceOf[MapStatus]
             val execId = status.location.executorId
@@ -1235,7 +1234,14 @@ class DAGScheduler(
       case Resubmitted =>
         logInfo("Resubmitted " + task + ", so marking it as still running")
-        stage.pendingPartitions += task.partitionId
+        stage match {
+          case sms: ShuffleMapStage =>
+            sms.pendingPartitions += task.partitionId
+
+          case _ =>
+            assert(false, "TaskSetManagers should only send Resubmitted task statuses for " +
+              "tasks in ShuffleMapStages.")
+        }

       case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
         val failedStage = stageIdToStage(task.stageId)
core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala

@@ -17,6 +17,8 @@
 package org.apache.spark.scheduler

+import scala.collection.mutable.HashSet
+
 import org.apache.spark.ShuffleDependency
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.BlockManagerId
@@ -47,6 +49,17 @@ private[spark] class ShuffleMapStage(

   private[this] var _numAvailableOutputs: Int = 0

+  /**
+   * Partitions that either haven't yet been computed, or that were computed on an executor
+   * that has since been lost, so should be re-computed. This variable is used by the
+   * DAGScheduler to determine when a stage has completed. Task successes in both the active
+   * attempt for the stage and in earlier attempts for this stage can cause partition ids to get
+   * removed from pendingPartitions. As a result, this variable may be inconsistent with the pending
+   * tasks in the TaskSetManager for the active attempt for the stage (the partitions stored here
+   * will always be a subset of the partitions that the TaskSetManager thinks are pending).
+   */
+  val pendingPartitions = new HashSet[Int]
+
   /**
    * List of [[MapStatus]] for each partition. The index of the array is the map partition id,
    * and each value in the array is the list of possible [[MapStatus]] for a partition
core/src/main/scala/org/apache/spark/scheduler/Stage.scala

@@ -67,8 +67,6 @@ private[scheduler] abstract class Stage(
   /** Set of jobs that this stage belongs to. */
   val jobIds = new HashSet[Int]

-  val pendingPartitions = new HashSet[Int]
-
   /** The ID to use for the next new attempt for this stage. */
   private var nextAttemptId: Int = 0
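For intuition about the doc comment above, here is a small self-contained model of the bookkeeping. The names (DemoShuffleStage, onMapTaskSuccess, onTaskResubmitted) are hypothetical stand-ins, not Spark APIs; only the pendingPartitions manipulation mirrors the code paths in the diff, where the DAGScheduler treats a stage as complete once the set drains:

    import scala.collection.mutable.HashSet

    // Hypothetical stand-in for ShuffleMapStage: every partition starts pending.
    final class DemoShuffleStage(numPartitions: Int) {
      val pendingPartitions = new HashSet[Int]
      pendingPartitions ++= 0 until numPartitions
    }

    object PendingPartitionsDemo {
      // Mirrors the Success/ShuffleMapTask branch: a successful map task --
      // whether from the active stage attempt or an earlier one -- removes its
      // partition, which is why the set can be a strict subset of what the
      // active TaskSetManager still considers pending.
      def onMapTaskSuccess(stage: DemoShuffleStage, partitionId: Int): Unit = {
        stage.pendingPartitions -= partitionId
        if (stage.pendingPartitions.isEmpty) {
          println("all map outputs available: stage complete")
        }
      }

      // Mirrors the Resubmitted branch: the task's output was lost (e.g. with
      // the executor that held it), so the partition becomes pending again.
      def onTaskResubmitted(stage: DemoShuffleStage, partitionId: Int): Unit = {
        stage.pendingPartitions += partitionId
      }

      def main(args: Array[String]): Unit = {
        val stage = new DemoShuffleStage(numPartitions = 3)
        onMapTaskSuccess(stage, 0)
        onTaskResubmitted(stage, 0) // its output was lost before being fetched
        onMapTaskSuccess(stage, 0)
        onMapTaskSuccess(stage, 1)
        onMapTaskSuccess(stage, 2) // set drains; completion message prints
      }
    }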