diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala
index cb78c2b58273da548b8aa2be17e8fb6a36ae1490..40680a625f0e7e671ccbbe76e07b2700971cd96b 100644
--- a/src/scala/spark/MesosScheduler.scala
+++ b/src/scala/spark/MesosScheduler.scala
@@ -88,18 +88,13 @@ extends MScheduler with spark.Scheduler with Logging
         this.activeJobsQueue += myJob
       }
       driver.reviveOffers();
-      myJob.join();
+      return myJob.join();
     } finally {
       this.synchronized {
         this.activeJobs.remove(myJob.jobId)
         this.activeJobsQueue.dequeueAll(x => (x == myJob))
       }
     }
-
-    if (myJob.errorHappened)
-      throw new SparkException(myJob.errorMessage, myJob.errorCode)
-    else
-      return myJob.results
   }
 
   override def registered(d: SchedulerDriver, frameworkId: String) {
diff --git a/src/scala/spark/SimpleJob.scala b/src/scala/spark/SimpleJob.scala
index 425dbe63667bf6dded73c4e3b1d4d9ada968b5f9..a8544e4474bb8f4b868b729853afbf8525677135 100644
--- a/src/scala/spark/SimpleJob.scala
+++ b/src/scala/spark/SimpleJob.scala
@@ -18,27 +18,28 @@ extends Job with Logging
   // Maximum time to wait to run a task in a preferred location (in ms)
   val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
 
-  // CPUs and memory to claim per task from Mesos
+  // CPUs and memory to request per task
   val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
   val MEM_PER_TASK = System.getProperty("spark.task.mem", "512").toInt
 
+  // Maximum times a task is allowed to fail before failing the job
+  val MAX_TASK_FAILURES = 4
+
   val callingThread = currentThread
   val numTasks = tasks.length
   val results = new Array[T](numTasks)
   val launched = new Array[Boolean](numTasks)
   val finished = new Array[Boolean](numTasks)
+  val numFailures = new Array[Int](numTasks)
   val tidToIndex = HashMap[Int, Int]()
 
   var allFinished = false
   val joinLock = new Object() // Used to wait for all tasks to finish
 
-  var errorHappened = false
-  var errorCode = 0
-  var errorMessage = ""
-
   var tasksLaunched = 0
   var tasksFinished = 0
 
+  // Last time when we launched a preferred task (for delay scheduling)
   var lastPreferredLaunchTime = System.currentTimeMillis
 
   // Queue of pending tasks for each node
@@ -47,6 +48,10 @@ extends Job with Logging
   // Queue containing all pending tasks
   val allPendingTasks = new Queue[Int]
 
+  // Did the job fail?
+  var failed = false
+  var causeOfFailure = ""
+
   for (i <- 0 until numTasks) {
     addPendingTask(i)
   }
@@ -58,6 +63,7 @@ extends Job with Logging
     }
   }
 
+  // Mark the job as finished and wake up any threads waiting on it
   def setAllFinished() {
     joinLock.synchronized {
       allFinished = true
@@ -65,10 +71,17 @@
     }
   }
 
-  def join() {
+  // Wait until the job finishes and return its results
+  def join(): Array[T] = {
     joinLock.synchronized {
-      while (!allFinished)
+      while (!allFinished) {
         joinLock.wait()
+      }
+      if (failed) {
+        throw new SparkException(causeOfFailure)
+      } else {
+        return results
+      }
     }
   }
 
@@ -193,6 +206,17 @@ extends Job with Logging
       tasksLaunched -= 1
       // Re-enqueue the task as pending
      addPendingTask(index)
+      // Mark it as failed
+      if (status.getState == TaskState.TASK_FAILED ||
+          status.getState == TaskState.TASK_LOST) {
+        numFailures(index) += 1
+        if (numFailures(index) > MAX_TASK_FAILURES) {
+          logError("Task %d:%d failed more than %d times; aborting job".format(
+            jobId, index, MAX_TASK_FAILURES))
+          abort("Task %d failed more than %d times".format(
+            index, MAX_TASK_FAILURES))
+        }
+      }
     } else {
       logInfo("Ignoring task-lost event for TID " + tid + " because task " + index + " is already finished")
     }
@@ -201,10 +225,16 @@
 
   def error(code: Int, message: String) {
     // Save the error message
-    errorHappened = true
-    errorCode = code
-    errorMessage = message
-    // Indicate to caller thread that we're done
-    setAllFinished()
+    abort("Mesos error: %s (error code: %d)".format(message, code))
+  }
+
+  def abort(message: String) {
+    joinLock.synchronized {
+      failed = true
+      causeOfFailure = message
+      // TODO: Kill running tasks if we were not terminated due to a Mesos error
+      // Indicate to any joining thread that we're done
+      setAllFinished()
+    }
   }
 }
diff --git a/src/scala/spark/SparkException.scala b/src/scala/spark/SparkException.scala
index 7257bf7b0cd1417ee42fcaa4714165f9a6c6d5c2..6f9be1a94fc2071209e0c02e0ef3009cc86abded 100644
--- a/src/scala/spark/SparkException.scala
+++ b/src/scala/spark/SparkException.scala
@@ -1,7 +1,3 @@
 package spark
 
-class SparkException(message: String) extends Exception(message) {
-  def this(message: String, errorCode: Int) {
-    this("%s (error code: %d)".format(message, errorCode))
-  }
-}
+class SparkException(message: String) extends Exception(message) {}
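
With this change, MesosScheduler.runTasks simply returns myJob.join(): the job either hands back its results array or throws a SparkException carrying causeOfFailure, so callers no longer inspect errorHappened/errorCode after the fact. Both failure paths now funnel through SimpleJob.abort: a Mesos-level error() callback, and a task that fails more than MAX_TASK_FAILURES (4) times, i.e. the job is aborted on a task's fifth failure. The sketch below is a standalone model of that join()/abort() handshake, not code from the patch itself; JoinAbortDemo, DemoJob and main are hypothetical names, while joinLock, allFinished, failed, causeOfFailure, setAllFinished(), abort() and join() mirror the fields and methods added to SimpleJob.

class SparkException(message: String) extends Exception(message) {}

object JoinAbortDemo {
  class DemoJob(numTasks: Int) {
    val results = new Array[Int](numTasks)
    val joinLock = new Object() // guards the flags below and blocks joiners
    var allFinished = false
    var failed = false
    var causeOfFailure = ""

    // Mark the job as finished and wake up any threads waiting in join()
    def setAllFinished() {
      joinLock.synchronized {
        allFinished = true
        joinLock.notifyAll()
      }
    }

    // Record the failure reason, then release the waiters (as abort() does)
    def abort(message: String) {
      joinLock.synchronized {
        failed = true
        causeOfFailure = message
        setAllFinished() // intrinsic locks are reentrant, so nesting is safe
      }
    }

    // Block until the job ends; return results or rethrow the saved cause
    def join(): Array[Int] = {
      joinLock.synchronized {
        while (!allFinished) {
          joinLock.wait()
        }
        if (failed) {
          throw new SparkException(causeOfFailure)
        } else {
          return results
        }
      }
    }
  }

  def main(args: Array[String]) {
    val job = new DemoJob(4)
    // Simulate the scheduler aborting after too many task failures
    new Thread(new Runnable {
      def run() { job.abort("Task 2 failed more than 4 times") }
    }).start()
    try {
      job.join()
    } catch {
      case e: SparkException => println("Job failed: " + e.getMessage)
    }
  }
}

Checking failed inside the same joinLock.synchronized block that guards the wait() loop is what makes the handoff race-free: abort() sets failed and causeOfFailure before notifying, so a woken joiner always observes the failure state before deciding whether to return results or throw.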