Commit aa8ccec3 authored by Matei Zaharia

Abort jobs if a task fails more than a limited number of times

parent 57a77842
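
For context, here is a minimal, self-contained sketch of the pattern this commit introduces: each task keeps a failure count, the job aborts once a task fails more than MAX_TASK_FAILURES times, and join() surfaces the failure as a SparkException instead of separate error fields. The names mirror the diff, but FailureTrackingSketch, Job, and taskFailed below are simplified stand-ins, not the actual Spark scheduler code.

object FailureTrackingSketch {
  class SparkException(message: String) extends Exception(message)

  class Job(numTasks: Int) {
    val MAX_TASK_FAILURES = 4                   // give up on a task after this many failures
    val numFailures = new Array[Int](numTasks)  // per-task failure counts
    val joinLock = new Object()                 // monitor used by join()
    var allFinished = false
    var failed = false
    var causeOfFailure = ""

    // Record one failure of task `index`; abort the whole job once it exceeds the limit
    def taskFailed(index: Int): Unit = {
      numFailures(index) += 1
      if (numFailures(index) > MAX_TASK_FAILURES) {
        abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
      }
    }

    // Mark the job as failed and wake up any thread blocked in join()
    def abort(message: String): Unit = joinLock.synchronized {
      failed = true
      causeOfFailure = message
      allFinished = true
      joinLock.notifyAll()
    }

    // Wait for the job; surface failure as an exception instead of a flag
    def join(): Unit = joinLock.synchronized {
      while (!allFinished) {
        joinLock.wait()
      }
      if (failed) {
        throw new SparkException(causeOfFailure)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val job = new Job(numTasks = 2)
    // Simulate task 0 failing five times: the fifth failure exceeds the limit and aborts the job
    for (_ <- 1 to 5) job.taskFailed(0)
    try {
      job.join()
    } catch {
      case e: SparkException => println("Job aborted: " + e.getMessage)
    }
  }
}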
@@ -88,18 +88,13 @@ extends MScheduler with spark.Scheduler with Logging
         this.activeJobsQueue += myJob
       }
       driver.reviveOffers();
-      myJob.join();
+      return myJob.join();
     } finally {
       this.synchronized {
         this.activeJobs.remove(myJob.jobId)
         this.activeJobsQueue.dequeueAll(x => (x == myJob))
       }
     }
-
-    if (myJob.errorHappened)
-      throw new SparkException(myJob.errorMessage, myJob.errorCode)
-    else
-      return myJob.results
   }
 
   override def registered(d: SchedulerDriver, frameworkId: String) {
@@ -18,27 +18,28 @@ extends Job with Logging
   // Maximum time to wait to run a task in a preferred location (in ms)
   val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
 
-  // CPUs and memory to claim per task from Mesos
+  // CPUs and memory to request per task
   val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
   val MEM_PER_TASK = System.getProperty("spark.task.mem", "512").toInt
 
+  // Maximum times a task is allowed to fail before failing the job
+  val MAX_TASK_FAILURES = 4
+
   val callingThread = currentThread
   val numTasks = tasks.length
   val results = new Array[T](numTasks)
   val launched = new Array[Boolean](numTasks)
   val finished = new Array[Boolean](numTasks)
+  val numFailures = new Array[Int](numTasks)
   val tidToIndex = HashMap[Int, Int]()
 
   var allFinished = false
   val joinLock = new Object() // Used to wait for all tasks to finish
 
-  var errorHappened = false
-  var errorCode = 0
-  var errorMessage = ""
-
   var tasksLaunched = 0
   var tasksFinished = 0
 
+  // Last time when we launched a preferred task (for delay scheduling)
   var lastPreferredLaunchTime = System.currentTimeMillis
 
   // Queue of pending tasks for each node
@@ -47,6 +48,10 @@ extends Job with Logging
   // Queue containing all pending tasks
   val allPendingTasks = new Queue[Int]
 
+  // Did the job fail?
+  var failed = false
+  var causeOfFailure = ""
+
   for (i <- 0 until numTasks) {
     addPendingTask(i)
   }
@@ -58,6 +63,7 @@ extends Job with Logging
     }
   }
 
+  // Mark the job as finished and wake up any threads waiting on it
   def setAllFinished() {
     joinLock.synchronized {
       allFinished = true
@@ -65,10 +71,17 @@ extends Job with Logging
     }
   }
 
-  def join() {
+  // Wait until the job finishes and return its results
+  def join(): Array[T] = {
     joinLock.synchronized {
-      while (!allFinished)
+      while (!allFinished) {
         joinLock.wait()
+      }
+      if (failed) {
+        throw new SparkException(causeOfFailure)
+      } else {
+        return results
+      }
     }
   }
@@ -193,6 +206,17 @@ extends Job with Logging
         tasksLaunched -= 1
         // Re-enqueue the task as pending
         addPendingTask(index)
+        // Mark it as failed
+        if (status.getState == TaskState.TASK_FAILED ||
+            status.getState == TaskState.TASK_LOST) {
+          numFailures(index) += 1
+          if (numFailures(index) > MAX_TASK_FAILURES) {
+            logError("Task %d:%d failed more than %d times; aborting job".format(
+              jobId, index, MAX_TASK_FAILURES))
+            abort("Task %d failed more than %d times".format(
+              index, MAX_TASK_FAILURES))
+          }
+        }
       } else {
         logInfo("Ignoring task-lost event for TID " + tid +
           " because task " + index + " is already finished")
@@ -201,10 +225,16 @@ extends Job with Logging
   def error(code: Int, message: String) {
     // Save the error message
-    errorHappened = true
-    errorCode = code
-    errorMessage = message
-    // Indicate to caller thread that we're done
-    setAllFinished()
+    abort("Mesos error: %s (error code: %d)".format(message, code))
+  }
+
+  def abort(message: String) {
+    joinLock.synchronized {
+      failed = true
+      causeOfFailure = message
+      // TODO: Kill running tasks if we were not terminated due to a Mesos error
+      // Indicate to any joining thread that we're done
+      setAllFinished()
+    }
   }
 }
 package spark
 
-class SparkException(message: String) extends Exception(message) {
-  def this(message: String, errorCode: Int) {
-    this("%s (error code: %d)".format(message, errorCode))
-  }
-}
+class SparkException(message: String) extends Exception(message) {}
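
With the (message, errorCode) constructor removed, any Mesos error code is folded into the message string before the exception is constructed, as error() now does. A small, hypothetical stand-alone example (SparkExceptionSketch is not part of the codebase):

object SparkExceptionSketch {
  // Mirrors the simplified SparkException: a plain message, no error-code field
  class SparkException(message: String) extends Exception(message)

  def main(args: Array[String]): Unit = {
    val code = 2
    // Fold the numeric code into the message, as error() does after this change
    val e = new SparkException("Mesos error: %s (error code: %d)".format("slave lost", code))
    println(e.getMessage)  // prints: Mesos error: slave lost (error code: 2)
  }
}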