diff --git a/src/scala/spark/SimpleJob.scala b/src/scala/spark/SimpleJob.scala
index 9664a4457807dc1eecccd0199b772d23ef93a6f5..425dbe63667bf6dded73c4e3b1d4d9ada968b5f9 100644
--- a/src/scala/spark/SimpleJob.scala
+++ b/src/scala/spark/SimpleJob.scala
@@ -3,6 +3,7 @@ package spark
 import java.util.{HashMap => JHashMap}
 
 import scala.collection.mutable.HashMap
+import scala.collection.mutable.Queue
 
 import mesos._
 
@@ -17,6 +18,10 @@ extends Job with Logging
   // Maximum time to wait to run a task in a preferred location (in ms)
   val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
 
+  // CPUs and memory to claim per task from Mesos
+  val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
+  val MEM_PER_TASK = System.getProperty("spark.task.mem", "512").toInt
+
   val callingThread = currentThread
   val numTasks = tasks.length
   val results = new Array[T](numTasks)
@@ -25,7 +30,7 @@ extends Job with Logging
   val tidToIndex = HashMap[Int, Int]()
 
   var allFinished = false
-  val joinLock = new Object()
+  val joinLock = new Object() // Used to wait for all tasks to finish
 
   var errorHappened = false
   var errorCode = 0
@@ -33,10 +38,25 @@ extends Job with Logging
 
   var tasksLaunched = 0
   var tasksFinished = 0
+
   var lastPreferredLaunchTime = System.currentTimeMillis
 
-  val cpusPerTask = System.getProperty("spark.task.cpus", "1").toInt
-  val memPerTask = System.getProperty("spark.task.mem", "512").toInt
+  // Queue of pending tasks for each node
+  val pendingTasksForNode = new HashMap[String, Queue[Int]]
+
+  // Queue containing all pending tasks
+  val allPendingTasks = new Queue[Int]
+
+  for (i <- 0 until numTasks) {
+    addPendingTask(i)
+  }
+
+  def addPendingTask(index: Int) {
+    allPendingTasks += index
+    for (host <- tasks(index).preferredLocations) {
+      pendingTasksForNode(host) += index
+    }
+  }
 
   def setAllFinished() {
     joinLock.synchronized {
@@ -52,41 +72,74 @@ extends Job with Logging
     }
   }
 
+  def getPendingTasksForNode(host: String): Queue[Int] = {
+    pendingTasksForNode.getOrElse(host, Queue())
+  }
+
+  // Dequeue a pending task from the given queue and return its index.
+  // Return None if the queue is empty.
+  def findTaskFromQueue(queue: Queue[Int]): Option[Int] = {
+    while (!queue.isEmpty) {
+      val index = queue.dequeue
+      if (!launched(index) && !finished(index)) {
+        return Some(index)
+      }
+    }
+    return None
+  }
+
+  // Dequeue a pending task for a given node and return its index.
+  // If localOnly is set to false, allow non-local tasks as well.
+  def findTask(host: String, localOnly: Boolean): Option[Int] = {
+    findTaskFromQueue(getPendingTasksForNode(host)) match {
+      case Some(task) => Some(task)
+      case None =>
+        if (localOnly) None
+        else findTaskFromQueue(allPendingTasks)
+    }
+  }
+
+  def isPreferredLocation(task: Task[T], host: String): Boolean = {
+    val locs = task.preferredLocations
+    return (locs.contains(host) || locs.isEmpty)
+  }
+
   def slaveOffer(offer: SlaveOffer, availableCpus: Int, availableMem: Int)
       : Option[TaskDescription] = {
-    if (tasksLaunched < numTasks) {
-      var checkPrefVals: Array[Boolean] = Array(true)
+    if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK &&
+        availableMem >= MEM_PER_TASK) {
       val time = System.currentTimeMillis
-      if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
-        checkPrefVals = Array(true, false) // Allow non-preferred tasks
-      if ((availableCpus < cpusPerTask) || (availableMem < memPerTask))
-        return None
-      for (checkPref <- checkPrefVals; i <- 0 until numTasks) {
-        if (!launched(i) && (!checkPref ||
-            tasks(i).preferredLocations.contains(offer.getHost) ||
-            tasks(i).preferredLocations.isEmpty))
-        {
+      val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
+      val host = offer.getHost
+      findTask(host, localOnly) match {
+        case Some(index) => {
+          val task = tasks(index)
           val taskId = sched.newTaskId()
-          sched.taskIdToJobId(taskId) = jobId
-          tidToIndex(taskId) = i
-          val preferred = if(checkPref) "preferred" else "non-preferred"
+          // Figure out whether the task's location is preferred
+          val preferred = isPreferredLocation(task, host)
+          val prefStr = if(preferred) "preferred" else "non-preferred"
           val message =
             "Starting task %d:%d as TID %s on slave %s: %s (%s)".format(
-              i, jobId, taskId, offer.getSlaveId, offer.getHost, preferred)
+              index, jobId, taskId, offer.getSlaveId, host, prefStr)
           logInfo(message)
-          tasks(i).markStarted(offer)
-          launched(i) = true
+          // Do various bookkeeping
+          sched.taskIdToJobId(taskId) = jobId
+          tidToIndex(taskId) = index
+          task.markStarted(offer)
+          launched(index) = true
           tasksLaunched += 1
-          if (checkPref)
+          if (preferred)
             lastPreferredLaunchTime = time
+          // Create and return the Mesos task object
           val params = new JHashMap[String, String]
-          params.put("cpus", "" + cpusPerTask)
-          params.put("mem", "" + memPerTask)
-          val serializedTask = Utils.serialize(tasks(i))
+          params.put("cpus", "" + CPUS_PER_TASK)
+          params.put("mem", "" + MEM_PER_TASK)
+          val serializedTask = Utils.serialize(task)
           logDebug("Serialized size: " + serializedTask.size)
           return Some(new TaskDescription(taskId, offer.getSlaveId,
             "task_" + taskId, params, serializedTask))
         }
+        case _ =>
       }
     }
     return None
@@ -138,6 +191,8 @@ extends Job with Logging
       launched(index) = false
       sched.taskIdToJobId.remove(tid)
       tasksLaunched -= 1
+      // Re-enqueue the task as pending
+      addPendingTask(index)
     } else {
       logInfo("Ignoring task-lost event for TID " + tid +
         " because task " + index + " is already finished")