diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala
index cb78c2b58273da548b8aa2be17e8fb6a36ae1490..40680a625f0e7e671ccbbe76e07b2700971cd96b 100644
--- a/src/scala/spark/MesosScheduler.scala
+++ b/src/scala/spark/MesosScheduler.scala
@@ -88,18 +88,13 @@ extends MScheduler with spark.Scheduler with Logging
         this.activeJobsQueue += myJob
       }
       driver.reviveOffers();
-      myJob.join();
+      return myJob.join();
     } finally {
       this.synchronized {
         this.activeJobs.remove(myJob.jobId)
         this.activeJobsQueue.dequeueAll(x => (x == myJob))
       }
     }
-
-    if (myJob.errorHappened)
-      throw new SparkException(myJob.errorMessage, myJob.errorCode)
-    else
-      return myJob.results
   }
 
   override def registered(d: SchedulerDriver, frameworkId: String) {
diff --git a/src/scala/spark/SimpleJob.scala b/src/scala/spark/SimpleJob.scala
index 425dbe63667bf6dded73c4e3b1d4d9ada968b5f9..a8544e4474bb8f4b868b729853afbf8525677135 100644
--- a/src/scala/spark/SimpleJob.scala
+++ b/src/scala/spark/SimpleJob.scala
@@ -18,27 +18,28 @@ extends Job with Logging
   // Maximum time to wait to run a task in a preferred location (in ms)
   val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
 
-  // CPUs and memory to claim per task from Mesos
+  // CPUs and memory to request per task
   val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
   val MEM_PER_TASK = System.getProperty("spark.task.mem", "512").toInt
 
+  // Maximum times a task is allowed to fail before failing the job
+  val MAX_TASK_FAILURES = 4
+
   val callingThread = currentThread
   val numTasks = tasks.length
   val results = new Array[T](numTasks)
   val launched = new Array[Boolean](numTasks)
   val finished = new Array[Boolean](numTasks)
+  val numFailures = new Array[Int](numTasks)
   val tidToIndex = HashMap[Int, Int]()
 
   var allFinished = false
   val joinLock = new Object() // Used to wait for all tasks to finish
 
-  var errorHappened = false
-  var errorCode = 0
-  var errorMessage = ""
-
   var tasksLaunched = 0
   var tasksFinished = 0
 
+  // Time at which we last launched a task at a preferred location (for delay scheduling)
   var lastPreferredLaunchTime = System.currentTimeMillis
 
   // Queue of pending tasks for each node
@@ -47,6 +48,10 @@ extends Job with Logging
   // Queue containing all pending tasks
   val allPendingTasks = new Queue[Int]
 
+  // Did the job fail?
+  var failed = false
+  var causeOfFailure = ""
+
   for (i <- 0 until numTasks) {
     addPendingTask(i)
   }
@@ -58,6 +63,7 @@ extends Job with Logging
     }
   }
 
+  // Mark the job as finished and wake up any threads waiting on it
   def setAllFinished() {
     joinLock.synchronized {
       allFinished = true
@@ -65,10 +71,17 @@ extends Job with Logging
     }
   }
 
-  def join() {
+  // Wait until the job finishes and return its results
+  def join(): Array[T] = {
     joinLock.synchronized {
-      while (!allFinished)
+      while (!allFinished) {
         joinLock.wait()
+      }
+      if (failed) {
+        throw new SparkException(causeOfFailure)
+      } else {
+        return results
+      }
     }
   }
 
@@ -193,6 +206,17 @@ extends Job with Logging
       tasksLaunched -= 1
       // Re-enqueue the task as pending
       addPendingTask(index)
+      // Count the failure and abort the job if the task has failed too many times
+      if (status.getState == TaskState.TASK_FAILED ||
+          status.getState == TaskState.TASK_LOST) {
+        numFailures(index) += 1
+        if (numFailures(index) > MAX_TASK_FAILURES) {
+          logError("Task %d:%d failed more than %d times; aborting job".format(
+            jobId, index, MAX_TASK_FAILURES))
+          abort("Task %d failed more than %d times".format(
+            index, MAX_TASK_FAILURES))
+        }
+      }
     } else {
       logInfo("Ignoring task-lost event for TID " + tid +
         " because task " + index + " is already finished")
@@ -201,10 +225,16 @@ extends Job with Logging
 
   def error(code: Int, message: String) {
     // Save the error message
-    errorHappened = true
-    errorCode = code
-    errorMessage = message
-    // Indicate to caller thread that we're done
-    setAllFinished()
+    abort("Mesos error: %s (error code: %d)".format(message, code))
+  }
+
+  def abort(message: String) {
+    joinLock.synchronized {
+      failed = true
+      causeOfFailure = message
+      // TODO: Kill running tasks if we were not terminated due to a Mesos error
+      // Indicate to any joining thread that we're done
+      setAllFinished()
+    }
   }
 }
diff --git a/src/scala/spark/SparkException.scala b/src/scala/spark/SparkException.scala
index 7257bf7b0cd1417ee42fcaa4714165f9a6c6d5c2..6f9be1a94fc2071209e0c02e0ef3009cc86abded 100644
--- a/src/scala/spark/SparkException.scala
+++ b/src/scala/spark/SparkException.scala
@@ -1,7 +1,3 @@
 package spark
 
-class SparkException(message: String) extends Exception(message) {
-  def this(message: String, errorCode: Int) {
-    this("%s (error code: %d)".format(message, errorCode))
-  }
-}
+class SparkException(message: String) extends Exception(message) {}