diff --git a/src/scala/spark/Job.scala b/src/scala/spark/Job.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6b01307adcc5e06d0e65ffa25f461b0406270c86
--- /dev/null
+++ b/src/scala/spark/Job.scala
@@ -0,0 +1,16 @@
+package spark
+
+import mesos._
+
+/**
+ * Trait representing a parallel job in MesosScheduler. The job schedules
+ * its tasks by implementing the callbacks below.
+ */
+trait Job {
+  def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int)
+    : Option[TaskDescription]
+
+  def statusUpdate(t: TaskStatus): Unit
+
+  def error(code: Int, message: String): Unit
+}
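
The three callbacks above are driven entirely by MesosScheduler: slaveOffer asks the job whether it can launch a task on an offered slave, statusUpdate reports task state transitions, and error surfaces scheduler-level failures. As a purely illustrative sketch (the NoOpJob name and behaviour are hypothetical; SimpleJob, added later in this patch, is the real implementation), a trivial Job might look like this:

package spark

import mesos._

// Hypothetical Job that declines every resource offer and only logs the
// other callbacks; it exists solely to show how the trait is wired up.
class NoOpJob extends Job with Logging {
  def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int)
      : Option[TaskDescription] = None // never launch a task

  def statusUpdate(t: TaskStatus): Unit =
    logInfo("Ignoring status update for TID " + t.getTaskId)

  def error(code: Int, message: String): Unit =
    logError("Job error %d: %s".format(code, message))
}
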
diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala
index ecb95d7c0bc09fc98d690b7a44d62e375f6c2fba..cb78c2b58273da548b8aa2be17e8fb6a36ae1490 100644
--- a/src/scala/spark/MesosScheduler.scala
+++ b/src/scala/spark/MesosScheduler.scala
@@ -1,13 +1,16 @@
 package spark
 
 import java.io.File
+import java.util.{ArrayList => JArrayList}
+import java.util.{List => JList}
+import java.util.{HashMap => JHashMap}
 
 import scala.collection.mutable.Map
 import scala.collection.mutable.Queue
 import scala.collection.mutable.HashMap
 import scala.collection.JavaConversions._
 
-import mesos.{Scheduler => NScheduler}
+import mesos.{Scheduler => MScheduler}
 import mesos._
 
 // The main Scheduler implementation, which talks to Mesos. Clients are expected
@@ -23,22 +26,20 @@ import mesos._
 //    all the offers to the Job and have it load-balance.
 private class MesosScheduler(
   master: String, frameworkName: String, execArg: Array[Byte])
-extends NScheduler with spark.Scheduler with Logging
+extends MScheduler with spark.Scheduler with Logging
 {
-  // Lock used by runTasks to ensure only one thread can be in it
-  val runTasksMutex = new Object()
-
   // Lock used to wait for the scheduler to be registered
   var isRegistered = false
   val registeredLock = new Object()
 
-  // Current callback object (may be null)
-  var activeJobsQueue = new Queue[Int]
   var activeJobs = new HashMap[Int, Job]
-  private var nextJobId = 0
+  var activeJobsQueue = new Queue[Job]
+
   private[spark] var taskIdToJobId = new HashMap[Int, Int]
+
+  private var nextJobId = 0
   
-  def newJobId(): Int = {
+  def newJobId(): Int = this.synchronized {
     val id = nextJobId
     nextJobId += 1
     return id
@@ -60,9 +61,9 @@ extends NScheduler with spark.Scheduler with Logging
     new Thread("Spark scheduler") {
       setDaemon(true)
       override def run {
-        val ns = MesosScheduler.this
-        ns.driver = new MesosSchedulerDriver(ns, master)
-        ns.driver.run()
+        val sched = MesosScheduler.this
+        sched.driver = new MesosSchedulerDriver(sched, master)
+        sched.driver.run()
       }
     }.start
   }
@@ -72,25 +73,26 @@ extends NScheduler with spark.Scheduler with Logging
   override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo =
     new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg)
 
+  /**
+   * The primary means of submitting a job to the scheduler: given an array of
+   * tasks, runs them and returns an array of their results.
+   */
   override def runTasks[T: ClassManifest](tasks: Array[Task[T]]): Array[T] = {
-    var jobId = 0
     waitForRegister()
-    this.synchronized {
-      jobId = newJobId()
-    }
+    val jobId = newJobId()
     val myJob = new SimpleJob(this, tasks, jobId)
 
     try {
       this.synchronized {
         this.activeJobs(myJob.jobId) = myJob
-        this.activeJobsQueue += myJob.jobId
+        this.activeJobsQueue += myJob
       }
       driver.reviveOffers();
       myJob.join();
     } finally {
       this.synchronized {
         this.activeJobs.remove(myJob.jobId)
-        this.activeJobsQueue.dequeueAll(x => (x == myJob.jobId))
+        this.activeJobsQueue.dequeueAll(x => (x == myJob))
       }
     }
 
@@ -116,35 +118,34 @@ extends NScheduler with spark.Scheduler with Logging
   }
 
   override def resourceOffer(
-      d: SchedulerDriver, oid: String, offers: java.util.List[SlaveOffer]) {
+      d: SchedulerDriver, oid: String, offers: JList[SlaveOffer]) {
     synchronized {
-      val tasks = new java.util.ArrayList[TaskDescription]
+      val tasks = new JArrayList[TaskDescription]
       val availableCpus = offers.map(_.getParams.get("cpus").toInt)
       val availableMem = offers.map(_.getParams.get("mem").toInt)
-      var launchedTask = true
-      for (jobId <- activeJobsQueue) {
-        launchedTask = true
-        while (launchedTask) {
+      var launchedTask = false
+      for (job <- activeJobsQueue) {
+        do {
           launchedTask = false
           for (i <- 0 until offers.size.toInt) {
             try {
-              activeJobs(jobId).slaveOffer(offers.get(i), availableCpus(i), availableMem(i)) match {
+              job.slaveOffer(offers(i), availableCpus(i), availableMem(i)) match {
                 case Some(task) =>
                   tasks.add(task)
                   availableCpus(i) -= task.getParams.get("cpus").toInt
                   availableMem(i) -= task.getParams.get("mem").toInt
-                  launchedTask = launchedTask || true
+                  launchedTask = true
                 case None => {}
               }
             } catch {
               case e: Exception => logError("Exception in resourceOffer", e)
             }
           }
-        }
+        } while (launchedTask)
       }
-      val params = new java.util.HashMap[String, String]
+      val params = new JHashMap[String, String]
       params.put("timeout", "1")
-      d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout
+      d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout?
     }
   }
 
@@ -167,8 +168,10 @@ extends NScheduler with spark.Scheduler with Logging
   }
 
   override def error(d: SchedulerDriver, code: Int, message: String) {
+    logError("Mesos error: %s (error code: %d)".format(message, code))
     synchronized {
       if (activeJobs.size > 0) {
+        // Have each job throw a SparkException with the error
         for ((jobId, activeJob) <- activeJobs) {
           try {
             activeJob.error(code, message)
@@ -177,7 +180,9 @@ extends NScheduler with spark.Scheduler with Logging
           }
         }
       } else {
-        logError("Mesos error: %s (error code: %d)".format(message, code))
+        // No jobs are active but we still got an error. Just exit since this
+        // must mean the error occurred during registration.
+        // It might be good to do something smarter here in the future.
         System.exit(1)
       }
     }
@@ -191,156 +196,3 @@ extends NScheduler with spark.Scheduler with Logging
   // TODO: query Mesos for number of cores
   override def numCores() = System.getProperty("spark.default.parallelism", "2").toInt
 }
-
-
-// Trait representing an object that manages a parallel operation by
-// implementing various scheduler callbacks.
-trait Job {
-  def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int): Option[TaskDescription]
-  def statusUpdate(t: TaskStatus): Unit
-  def error(code: Int, message: String): Unit
-}
-
-
-class SimpleJob[T: ClassManifest](
-  sched: MesosScheduler, tasks: Array[Task[T]], val jobId: Int)
-extends Job with Logging
-{
-  // Maximum time to wait to run a task in a preferred location (in ms)
-  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
-
-  val callingThread = currentThread
-  val numTasks = tasks.length
-  val results = new Array[T](numTasks)
-  val launched = new Array[Boolean](numTasks)
-  val finished = new Array[Boolean](numTasks)
-  val tidToIndex = Map[Int, Int]()
-
-  var allFinished = false
-  val joinLock = new Object()
-
-  var errorHappened = false
-  var errorCode = 0
-  var errorMessage = ""
-
-  var tasksLaunched = 0
-  var tasksFinished = 0
-  var lastPreferredLaunchTime = System.currentTimeMillis
-
-  def setAllFinished() {
-    joinLock.synchronized {
-      allFinished = true
-      joinLock.notifyAll()
-    }
-  }
-
-  def join() {
-    joinLock.synchronized {
-      while (!allFinished)
-        joinLock.wait()
-    }
-  }
-
-  def slaveOffer(offer: SlaveOffer, availableCpus: Int, availableMem: Int): Option[TaskDescription] = {
-    if (tasksLaunched < numTasks) {
-      var checkPrefVals: Array[Boolean] = Array(true)
-      val time = System.currentTimeMillis
-      if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
-        checkPrefVals = Array(true, false) // Allow non-preferred tasks
-      // TODO: Make desiredCpus and desiredMem configurable
-      val desiredCpus = 1
-      val desiredMem = 500
-      if ((availableCpus < desiredCpus) || (availableMem < desiredMem))
-        return None
-      for (checkPref <- checkPrefVals; i <- 0 until numTasks) {
-        if (!launched(i) && (!checkPref ||
-            tasks(i).preferredLocations.contains(offer.getHost) ||
-            tasks(i).preferredLocations.isEmpty))
-        {
-          val taskId = sched.newTaskId()
-          sched.taskIdToJobId(taskId) = jobId
-          tidToIndex(taskId) = i
-          val preferred = if(checkPref) "preferred" else "non-preferred"
-          val message =
-            "Starting task %d as jobId %d, TID %s on slave %s: %s (%s)".format(
-              i, jobId, taskId, offer.getSlaveId, offer.getHost, preferred)
-          logInfo(message)
-          tasks(i).markStarted(offer)
-          launched(i) = true
-          tasksLaunched += 1
-          if (checkPref)
-            lastPreferredLaunchTime = time
-          val params = new java.util.HashMap[String, String]
-          params.put("cpus", "" + desiredCpus)
-          params.put("mem", "" + desiredMem)
-          val serializedTask = Utils.serialize(tasks(i))
-          //logInfo("Serialized size: " + serializedTask.size)
-          return Some(new TaskDescription(taskId, offer.getSlaveId,
-            "task_" + taskId, params, serializedTask))
-        }
-      }
-    }
-    return None
-  }
-
-  def statusUpdate(status: TaskStatus) {
-    status.getState match {
-      case TaskState.TASK_FINISHED =>
-        taskFinished(status)
-      case TaskState.TASK_LOST =>
-        taskLost(status)
-      case TaskState.TASK_FAILED =>
-        taskLost(status)
-      case TaskState.TASK_KILLED =>
-        taskLost(status)
-      case _ =>
-    }
-  }
-
-  def taskFinished(status: TaskStatus) {
-    val tid = status.getTaskId
-    val index = tidToIndex(tid)
-    if (!finished(index)) {
-      tasksFinished += 1
-      logInfo("Finished job %d TID %d (progress: %d/%d)".format(
-        jobId, tid, tasksFinished, numTasks))
-      // Deserialize task result
-      val result = Utils.deserialize[TaskResult[T]](status.getData)
-      results(index) = result.value
-      // Update accumulators
-      Accumulators.add(callingThread, result.accumUpdates)
-      // Mark finished and stop if we've finished all the tasks
-      finished(index) = true
-      // Remove TID -> jobId mapping from sched
-      sched.taskIdToJobId.remove(tid)
-      if (tasksFinished == numTasks)
-        setAllFinished()
-    } else {
-      logInfo("Ignoring task-finished event for TID " + tid +
-        " because task " + index + " is already finished")
-    }
-  }
-
-  def taskLost(status: TaskStatus) {
-    val tid = status.getTaskId
-    val index = tidToIndex(tid)
-    if (!finished(index)) {
-      logInfo("Lost job " + jobId + " TID " + tid)
-      launched(index) = false
-      sched.taskIdToJobId.remove(tid)
-      tasksLaunched -= 1
-    } else {
-      logInfo("Ignoring task-lost event for TID " + tid +
-        " because task " + index + " is already finished")
-    }
-  }
-
-  def error(code: Int, message: String) {
-    // Save the error message
-    errorHappened = true
-    errorCode = code
-    errorMessage = message
-    // Indicate to caller thread that we're done
-    setAllFinished()
-  }
-}
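
The rewritten resourceOffer above iterates the active jobs in FIFO order and, for each one, keeps making passes over all offers until a full pass launches nothing (the do/while), so a job can pack several tasks onto one offer before the next job gets a turn. The following self-contained sketch shows just that packing strategy with simplified, hypothetical types (Offer, SketchJob and pack are illustrative and independent of the Mesos API):

import scala.collection.mutable.ArrayBuffer

// Illustrative sketch of the FIFO offer-packing loop in resourceOffer.
object OfferPackingSketch {
  case class Offer(var cpus: Int, var mem: Int)
  // A "job" is modelled as a function that may claim resources from an offer
  // and return a launched task name, or None if nothing (more) fits.
  type SketchJob = Offer => Option[String]

  def pack(jobs: Seq[SketchJob], offers: Array[Offer]): Seq[String] = {
    val launched = new ArrayBuffer[String]
    for (job <- jobs) {      // FIFO over the active jobs
      var launchedTask = false
      do {                   // repeat while this job still launched something
        launchedTask = false
        for (i <- 0 until offers.size) {
          job(offers(i)) match {
            case Some(name) =>
              launched += name
              launchedTask = true
            case None => // this offer cannot fit another task for this job
          }
        }
      } while (launchedTask)
    }
    launched
  }
}
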
diff --git a/src/scala/spark/SimpleJob.scala b/src/scala/spark/SimpleJob.scala
new file mode 100644
index 0000000000000000000000000000000000000000..9664a4457807dc1eecccd0199b772d23ef93a6f5
--- /dev/null
+++ b/src/scala/spark/SimpleJob.scala
@@ -0,0 +1,155 @@
+package spark
+
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.mutable.HashMap
+
+import mesos._
+
+
+/**
+ * A simple implementation of Job that just runs each task in an array.
+ */
+class SimpleJob[T: ClassManifest](
+  sched: MesosScheduler, tasks: Array[Task[T]], val jobId: Int)
+extends Job with Logging
+{
+  // Maximum time to wait to run a task in a preferred location (in ms)
+  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
+
+  val callingThread = currentThread
+  val numTasks = tasks.length
+  val results = new Array[T](numTasks)
+  val launched = new Array[Boolean](numTasks)
+  val finished = new Array[Boolean](numTasks)
+  val tidToIndex = HashMap[Int, Int]()
+
+  var allFinished = false
+  val joinLock = new Object()
+
+  var errorHappened = false
+  var errorCode = 0
+  var errorMessage = ""
+
+  var tasksLaunched = 0
+  var tasksFinished = 0
+  var lastPreferredLaunchTime = System.currentTimeMillis
+
+  val cpusPerTask = System.getProperty("spark.task.cpus", "1").toInt
+  val memPerTask = System.getProperty("spark.task.mem", "512").toInt
+
+  def setAllFinished() {
+    joinLock.synchronized {
+      allFinished = true
+      joinLock.notifyAll()
+    }
+  }
+
+  def join() {
+    joinLock.synchronized {
+      while (!allFinished)
+        joinLock.wait()
+    }
+  }
+
+  def slaveOffer(offer: SlaveOffer, availableCpus: Int, availableMem: Int)
+      : Option[TaskDescription] = {
+    if (tasksLaunched < numTasks) {
+      var checkPrefVals: Array[Boolean] = Array(true)
+      val time = System.currentTimeMillis
+      if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
+        checkPrefVals = Array(true, false) // Allow non-preferred tasks
+      if ((availableCpus < cpusPerTask) || (availableMem < memPerTask))
+        return None
+      for (checkPref <- checkPrefVals; i <- 0 until numTasks) {
+        if (!launched(i) && (!checkPref ||
+            tasks(i).preferredLocations.contains(offer.getHost) ||
+            tasks(i).preferredLocations.isEmpty))
+        {
+          val taskId = sched.newTaskId()
+          sched.taskIdToJobId(taskId) = jobId
+          tidToIndex(taskId) = i
+          val preferred = if (checkPref) "preferred" else "non-preferred"
+          val message =
+            "Starting task %d:%d as TID %s on slave %s: %s (%s)".format(
+              jobId, i, taskId, offer.getSlaveId, offer.getHost, preferred)
+          logInfo(message)
+          tasks(i).markStarted(offer)
+          launched(i) = true
+          tasksLaunched += 1
+          if (checkPref)
+            lastPreferredLaunchTime = time
+          val params = new JHashMap[String, String]
+          params.put("cpus", "" + cpusPerTask)
+          params.put("mem", "" + memPerTask)
+          val serializedTask = Utils.serialize(tasks(i))
+          logDebug("Serialized size: " + serializedTask.size)
+          return Some(new TaskDescription(taskId, offer.getSlaveId,
+            "task_" + taskId, params, serializedTask))
+        }
+      }
+    }
+    return None
+  }
+
+  def statusUpdate(status: TaskStatus) {
+    status.getState match {
+      case TaskState.TASK_FINISHED =>
+        taskFinished(status)
+      case TaskState.TASK_LOST =>
+        taskLost(status)
+      case TaskState.TASK_FAILED =>
+        taskLost(status)
+      case TaskState.TASK_KILLED =>
+        taskLost(status)
+      case _ =>
+    }
+  }
+
+  def taskFinished(status: TaskStatus) {
+    val tid = status.getTaskId
+    val index = tidToIndex(tid)
+    if (!finished(index)) {
+      tasksFinished += 1
+      logInfo("Finished TID %d (progress: %d/%d)".format(
+        tid, tasksFinished, numTasks))
+      // Deserialize task result
+      val result = Utils.deserialize[TaskResult[T]](status.getData)
+      results(index) = result.value
+      // Update accumulators
+      Accumulators.add(callingThread, result.accumUpdates)
+      // Mark finished and stop if we've finished all the tasks
+      finished(index) = true
+      // Remove TID -> jobId mapping from sched
+      sched.taskIdToJobId.remove(tid)
+      if (tasksFinished == numTasks)
+        setAllFinished()
+    } else {
+      logInfo("Ignoring task-finished event for TID " + tid +
+        " because task " + index + " is already finished")
+    }
+  }
+
+  def taskLost(status: TaskStatus) {
+    val tid = status.getTaskId
+    val index = tidToIndex(tid)
+    if (!finished(index)) {
+      logInfo("Lost TID %d (task %d:%d)".format(tid, jobId, index))
+      launched(index) = false
+      sched.taskIdToJobId.remove(tid)
+      tasksLaunched -= 1
+    } else {
+      logInfo("Ignoring task-lost event for TID " + tid +
+        " because task " + index + " is already finished")
+    }
+  }
+
+  def error(code: Int, message: String) {
+    // Save the error code and message
+    errorHappened = true
+    errorCode = code
+    errorMessage = message
+    // Indicate to the calling thread that we're done
+    setAllFinished()
+  }
+}
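
SimpleJob now reads the per-task CPU and memory requirements from system properties instead of the previously hard-coded desiredCpus = 1 and desiredMem = 500, alongside the existing spark.locality.wait setting. A hypothetical usage sketch for overriding them (the values below are arbitrary; the defaults in this patch are 1 CPU, 512 MB and 3000 ms):

// Hypothetical example of overriding the new per-task settings before the
// scheduler is created; the chosen values are arbitrary examples.
object TaskSettingsExample {
  def main(args: Array[String]) {
    System.setProperty("spark.task.cpus", "2")          // CPUs requested per task
    System.setProperty("spark.task.mem", "1024")        // memory (MB) requested per task
    System.setProperty("spark.locality.wait", "5000")   // ms to wait for a preferred host
    // ... then construct the SparkContext / MesosScheduler as usual ...
  }
}
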