Skip to content
Snippets Groups Projects
Commit 2587ce16 authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Fixed a deadlock that occured with MesosScheduler due to an earlier

synchronization change
parent 98f008b7
No related branches found
No related tags found
No related merge requests found
......@@ -58,6 +58,7 @@ private trait DAGScheduler extends Scheduler with Logging {
val POLL_TIMEOUT = 500L
private val completionEvents = new LinkedBlockingQueue[CompletionEvent]
private val lock = new Object
var nextStageId = 0
......@@ -164,157 +165,158 @@ private trait DAGScheduler extends Scheduler with Logging {
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
allowLocal: Boolean)
(implicit m: ClassManifest[U]): Array[U] =
synchronized {
val outputParts = partitions.toArray
val numOutputParts: Int = partitions.size
val finalStage = newStage(finalRdd, None)
val results = new Array[U](numOutputParts)
val finished = new Array[Boolean](numOutputParts)
var numFinished = 0
val waiting = new HashSet[Stage] // stages we need to run whose parents aren't done
val running = new HashSet[Stage] // stages we are running right now
val failed = new HashSet[Stage] // stages that must be resubmitted due to fetch failures
val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // missing tasks from each stage
var lastFetchFailureTime: Long = 0 // used to wait a bit to avoid repeated resubmits
SparkEnv.set(env)
updateCacheLocs()
logInfo("Final stage: " + finalStage)
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
// Optimization for short actions like first() and take() that can be computed locally
// without shipping tasks to the cluster.
if (allowLocal && finalStage.parents.size == 0 && numOutputParts == 1) {
logInfo("Computing the requested partition locally")
val split = finalRdd.splits(outputParts(0))
val taskContext = new TaskContext(finalStage.id, outputParts(0), 0)
return Array(func(taskContext, finalRdd.iterator(split)))
}
def submitStage(stage: Stage) {
if (!waiting(stage) && !running(stage)) {
val missing = getMissingParentStages(stage)
if (missing == Nil) {
logInfo("Submitting " + stage + ", which has no missing parents")
submitMissingTasks(stage)
running += stage
} else {
for (parent <- missing) {
submitStage(parent)
(implicit m: ClassManifest[U]): Array[U] = {
lock.synchronized {
val outputParts = partitions.toArray
val numOutputParts: Int = partitions.size
val finalStage = newStage(finalRdd, None)
val results = new Array[U](numOutputParts)
val finished = new Array[Boolean](numOutputParts)
var numFinished = 0
val waiting = new HashSet[Stage] // stages we need to run whose parents aren't done
val running = new HashSet[Stage] // stages we are running right now
val failed = new HashSet[Stage] // stages that must be resubmitted due to fetch failures
val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // missing tasks from each stage
var lastFetchFailureTime: Long = 0 // used to wait a bit to avoid repeated resubmits
SparkEnv.set(env)
updateCacheLocs()
logInfo("Final stage: " + finalStage)
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
// Optimization for short actions like first() and take() that can be computed locally
// without shipping tasks to the cluster.
if (allowLocal && finalStage.parents.size == 0 && numOutputParts == 1) {
logInfo("Computing the requested partition locally")
val split = finalRdd.splits(outputParts(0))
val taskContext = new TaskContext(finalStage.id, outputParts(0), 0)
return Array(func(taskContext, finalRdd.iterator(split)))
}
def submitStage(stage: Stage) {
if (!waiting(stage) && !running(stage)) {
val missing = getMissingParentStages(stage)
if (missing == Nil) {
logInfo("Submitting " + stage + ", which has no missing parents")
submitMissingTasks(stage)
running += stage
} else {
for (parent <- missing) {
submitStage(parent)
}
waiting += stage
}
waiting += stage
}
}
}
def submitMissingTasks(stage: Stage) {
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
var tasks = ArrayBuffer[Task[_]]()
if (stage == finalStage) {
for (id <- 0 until numOutputParts if (!finished(id))) {
val part = outputParts(id)
val locs = getPreferredLocs(finalRdd, part)
tasks += new ResultTask(finalStage.id, finalRdd, func, part, locs, id)
}
} else {
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
val locs = getPreferredLocs(stage.rdd, p)
tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
def submitMissingTasks(stage: Stage) {
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
var tasks = ArrayBuffer[Task[_]]()
if (stage == finalStage) {
for (id <- 0 until numOutputParts if (!finished(id))) {
val part = outputParts(id)
val locs = getPreferredLocs(finalRdd, part)
tasks += new ResultTask(finalStage.id, finalRdd, func, part, locs, id)
}
} else {
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
val locs = getPreferredLocs(stage.rdd, p)
tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
}
}
myPending ++= tasks
submitTasks(tasks)
}
myPending ++= tasks
submitTasks(tasks)
}
submitStage(finalStage)
while (numFinished != numOutputParts) {
val evt = completionEvents.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
val time = System.currentTimeMillis // TODO: use a pluggable clock for testability
// If we got an event off the queue, mark the task done or react to a fetch failure
if (evt != null) {
val stage = idToStage(evt.task.stageId)
pendingTasks(stage) -= evt.task
if (evt.reason == Success) {
// A task ended
logInfo("Completed " + evt.task)
Accumulators.add(evt.accumUpdates)
evt.task match {
case rt: ResultTask[_, _] =>
results(rt.outputId) = evt.result.asInstanceOf[U]
finished(rt.outputId) = true
numFinished += 1
case smt: ShuffleMapTask =>
val stage = idToStage(smt.stageId)
stage.addOutputLoc(smt.partition, evt.result.asInstanceOf[String])
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
logInfo(stage + " finished; looking for newly runnable stages")
running -= stage
if (stage.shuffleDep != None) {
mapOutputTracker.registerMapOutputs(
stage.shuffleDep.get.shuffleId,
stage.outputLocs.map(_.head).toArray)
}
updateCacheLocs()
val newlyRunnable = new ArrayBuffer[Stage]
for (stage <- waiting if getMissingParentStages(stage) == Nil) {
newlyRunnable += stage
}
waiting --= newlyRunnable
running ++= newlyRunnable
for (stage <- newlyRunnable) {
submitMissingTasks(stage)
submitStage(finalStage)
while (numFinished != numOutputParts) {
val evt = completionEvents.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
val time = System.currentTimeMillis // TODO: use a pluggable clock for testability
// If we got an event off the queue, mark the task done or react to a fetch failure
if (evt != null) {
val stage = idToStage(evt.task.stageId)
pendingTasks(stage) -= evt.task
if (evt.reason == Success) {
// A task ended
logInfo("Completed " + evt.task)
Accumulators.add(evt.accumUpdates)
evt.task match {
case rt: ResultTask[_, _] =>
results(rt.outputId) = evt.result.asInstanceOf[U]
finished(rt.outputId) = true
numFinished += 1
case smt: ShuffleMapTask =>
val stage = idToStage(smt.stageId)
stage.addOutputLoc(smt.partition, evt.result.asInstanceOf[String])
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
logInfo(stage + " finished; looking for newly runnable stages")
running -= stage
if (stage.shuffleDep != None) {
mapOutputTracker.registerMapOutputs(
stage.shuffleDep.get.shuffleId,
stage.outputLocs.map(_.head).toArray)
}
updateCacheLocs()
val newlyRunnable = new ArrayBuffer[Stage]
for (stage <- waiting if getMissingParentStages(stage) == Nil) {
newlyRunnable += stage
}
waiting --= newlyRunnable
running ++= newlyRunnable
for (stage <- newlyRunnable) {
submitMissingTasks(stage)
}
}
}
}
} else {
evt.reason match {
case FetchFailed(serverUri, shuffleId, mapId, reduceId) =>
// Mark the stage that the reducer was in as unrunnable
val failedStage = idToStage(evt.task.stageId)
running -= failedStage
failed += failedStage
// TODO: Cancel running tasks in the stage
logInfo("Marking " + failedStage + " for resubmision due to a fetch failure")
// Mark the map whose fetch failed as broken in the map stage
val mapStage = shuffleToMapStage(shuffleId)
mapStage.removeOutputLoc(mapId, serverUri)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, serverUri)
logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
failed += mapStage
// Remember that a fetch failed now; this is used to resubmit the broken
// stages later, after a small wait (to give other tasks the chance to fail)
lastFetchFailureTime = time
// TODO: If there are a lot of fetch failures on the same node, maybe mark all
// outputs on the node as dead.
case _ =>
// Non-fetch failure -- probably a bug in the job, so bail out
throw new SparkException("Task failed: " + evt.task + ", reason: " + evt.reason)
// TODO: Cancel all tasks that are still running
}
}
} else {
evt.reason match {
case FetchFailed(serverUri, shuffleId, mapId, reduceId) =>
// Mark the stage that the reducer was in as unrunnable
val failedStage = idToStage(evt.task.stageId)
running -= failedStage
failed += failedStage
// TODO: Cancel running tasks in the stage
logInfo("Marking " + failedStage + " for resubmision due to a fetch failure")
// Mark the map whose fetch failed as broken in the map stage
val mapStage = shuffleToMapStage(shuffleId)
mapStage.removeOutputLoc(mapId, serverUri)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, serverUri)
logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
failed += mapStage
// Remember that a fetch failed now; this is used to resubmit the broken
// stages later, after a small wait (to give other tasks the chance to fail)
lastFetchFailureTime = time
// TODO: If there are a lot of fetch failures on the same node, maybe mark all
// outputs on the node as dead.
case _ =>
// Non-fetch failure -- probably a bug in the job, so bail out
throw new SparkException("Task failed: " + evt.task + ", reason: " + evt.reason)
// TODO: Cancel all tasks that are still running
} // end if (evt != null)
// If fetches have failed recently and we've waited for the right timeout,
// resubmit all the failed stages
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
logInfo("Resubmitting failed stages")
updateCacheLocs()
for (stage <- failed) {
submitStage(stage)
}
failed.clear()
}
} // end if (evt != null)
// If fetches have failed recently and we've waited for the right timeout,
// resubmit all the failed stages
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
logInfo("Resubmitting failed stages")
updateCacheLocs()
for (stage <- failed) {
submitStage(stage)
}
failed.clear()
}
return results
}
return results
}
def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment