diff --git a/src/scala/spark/SimpleJob.scala b/src/scala/spark/SimpleJob.scala
index 9664a4457807dc1eecccd0199b772d23ef93a6f5..425dbe63667bf6dded73c4e3b1d4d9ada968b5f9 100644
--- a/src/scala/spark/SimpleJob.scala
+++ b/src/scala/spark/SimpleJob.scala
@@ -3,6 +3,7 @@ package spark
 import java.util.{HashMap => JHashMap}
 
 import scala.collection.mutable.HashMap
+import scala.collection.mutable.Queue
 
 import mesos._
 
@@ -17,6 +18,10 @@ extends Job with Logging
   // Maximum time to wait to run a task in a preferred location (in ms)
   val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
 
+  // CPUs and memory to claim per task from Mesos
+  val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
+  val MEM_PER_TASK = System.getProperty("spark.task.mem", "512").toInt
+
   val callingThread = currentThread
   val numTasks = tasks.length
   val results = new Array[T](numTasks)
@@ -25,7 +30,7 @@ extends Job with Logging
   val tidToIndex = HashMap[Int, Int]()
 
   var allFinished = false
-  val joinLock = new Object()
+  val joinLock = new Object() // Used to wait for all tasks to finish
 
   var errorHappened = false
   var errorCode = 0
@@ -33,10 +38,37 @@ extends Job with Logging
 
   var tasksLaunched = 0
   var tasksFinished = 0
+
   var lastPreferredLaunchTime = System.currentTimeMillis
 
-  val cpusPerTask = System.getProperty("spark.task.cpus", "1").toInt
-  val memPerTask = System.getProperty("spark.task.mem", "512").toInt
+  // Queue of pending tasks for each node
+  val pendingTasksForNode = new HashMap[String, Queue[Int]]
+
+  // Queue of pending tasks that have no preferred locations; these can run
+  // anywhere without waiting out the locality delay
+  val pendingTasksWithNoPrefs = new Queue[Int]
+
+  // Queue containing all pending tasks
+  val allPendingTasks = new Queue[Int]
+
+  for (i <- 0 until numTasks) {
+    addPendingTask(i)
+  }
+
+  // Register a task as pending: in the global queue, plus either the queue of
+  // each preferred node or the no-preference queue. Node queues are created
+  // lazily, since pendingTasksForNode starts out empty.
+  def addPendingTask(index: Int) {
+    allPendingTasks += index
+    val locations = tasks(index).preferredLocations
+    if (locations.isEmpty) {
+      pendingTasksWithNoPrefs += index
+    } else {
+      for (host <- locations) {
+        pendingTasksForNode.getOrElseUpdate(host, new Queue[Int]) += index
+      }
+    }
+  }
 
   def setAllFinished() {
     joinLock.synchronized {
@@ -52,41 +84,78 @@ extends Job with Logging
     }
   }
 
+  // Return the queue of pending tasks for a node, or an empty queue if no
+  // pending task prefers that node.
+  def getPendingTasksForNode(host: String): Queue[Int] = {
+    pendingTasksForNode.getOrElse(host, Queue())
+  }
+
+  // Dequeue a pending task from the given queue and return its index,
+  // skipping entries for tasks that were already launched or finished.
+  // Return None if the queue runs out of such tasks.
+  def findTaskFromQueue(queue: Queue[Int]): Option[Int] = {
+    while (!queue.isEmpty) {
+      val index = queue.dequeue
+      if (!launched(index) && !finished(index)) {
+        return Some(index)
+      }
+    }
+    return None
+  }
+
+  // Dequeue a pending task for a given node and return its index.
+  // Tasks local to the node are tried first, then tasks with no preferred
+  // location (which are never worth delaying for locality).
+  // If localOnly is set to false, allow non-local tasks as well.
+  def findTask(host: String, localOnly: Boolean): Option[Int] = {
+    findTaskFromQueue(getPendingTasksForNode(host)) orElse
+      findTaskFromQueue(pendingTasksWithNoPrefs) orElse
+      (if (localOnly) None else findTaskFromQueue(allPendingTasks))
+  }
+
+  // A location counts as preferred if the task lists it, or if the task has
+  // no location preferences at all.
+  def isPreferredLocation(task: Task[T], host: String): Boolean = {
+    val locs = task.preferredLocations
+    return (locs.contains(host) || locs.isEmpty)
+  }
+
   def slaveOffer(offer: SlaveOffer, availableCpus: Int, availableMem: Int)
       : Option[TaskDescription] = {
-    if (tasksLaunched < numTasks) {
-      var checkPrefVals: Array[Boolean] = Array(true)
+    if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK &&
+        availableMem >= MEM_PER_TASK) {
       val time = System.currentTimeMillis
-      if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
-        checkPrefVals = Array(true, false) // Allow non-preferred tasks
-      if ((availableCpus < cpusPerTask) || (availableMem < memPerTask))
-        return None
-      for (checkPref <- checkPrefVals; i <- 0 until numTasks) {
-        if (!launched(i) && (!checkPref ||
-            tasks(i).preferredLocations.contains(offer.getHost) ||
-            tasks(i).preferredLocations.isEmpty))
-        {
+      val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
+      val host = offer.getHost
+      findTask(host, localOnly) match {
+        case Some(index) => {
+          val task = tasks(index)
           val taskId = sched.newTaskId()
-          sched.taskIdToJobId(taskId) = jobId
-          tidToIndex(taskId) = i
-          val preferred = if(checkPref) "preferred" else "non-preferred"
+          // Figure out whether the task's location is preferred
+          val preferred = isPreferredLocation(task, host)
+          val prefStr = if(preferred) "preferred" else "non-preferred"
           val message =
             "Starting task %d:%d as TID %s on slave %s: %s (%s)".format(
-              i, jobId, taskId, offer.getSlaveId, offer.getHost, preferred)
+              index, jobId, taskId, offer.getSlaveId, host, prefStr)
           logInfo(message)
-          tasks(i).markStarted(offer)
-          launched(i) = true
+          // Do various bookkeeping
+          sched.taskIdToJobId(taskId) = jobId
+          tidToIndex(taskId) = index
+          task.markStarted(offer)
+          launched(index) = true
           tasksLaunched += 1
-          if (checkPref)
+          if (preferred)
             lastPreferredLaunchTime = time
+          // Create and return the Mesos task object
           val params = new JHashMap[String, String]
-          params.put("cpus", "" + cpusPerTask)
-          params.put("mem", "" + memPerTask)
-          val serializedTask = Utils.serialize(tasks(i))
+          params.put("cpus", "" + CPUS_PER_TASK)
+          params.put("mem", "" + MEM_PER_TASK)
+          val serializedTask = Utils.serialize(task)
           logDebug("Serialized size: " + serializedTask.size)
           return Some(new TaskDescription(taskId, offer.getSlaveId,
             "task_" + taskId, params, serializedTask))
         }
+        case _ =>
       }
     }
     return None
@@ -138,6 +207,8 @@ extends Job with Logging
         launched(index) = false
         sched.taskIdToJobId.remove(tid)
         tasksLaunched -= 1
+        // Re-enqueue the task as pending
+        addPendingTask(index)
       } else {
         logInfo("Ignoring task-lost event for TID " + tid +
           " because task " + index + " is already finished")