diff --git a/src/scala/spark/Job.scala b/src/scala/spark/Job.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6b01307adcc5e06d0e65ffa25f461b0406270c86
--- /dev/null
+++ b/src/scala/spark/Job.scala
@@ -0,0 +1,16 @@
+package spark
+
+import mesos._
+
+/**
+ * Trait representing a parallel job in MesosScheduler. Schedules the
+ * job by implementing various callbacks.
+ */
+trait Job {
+  def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int)
+      : Option[TaskDescription]
+
+  def statusUpdate(t: TaskStatus): Unit
+
+  def error(code: Int, message: String): Unit
+}
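Reviewer note (illustration only, not part of this patch): the Job trait above is the whole contract a scheduler plugin has to satisfy. A minimal, purely hypothetical implementation that declines every offer might look like the sketch below; the NoopJob name is made up, and only the mesos types and Logging calls that already appear in this diff are assumed.

    package spark

    import mesos._

    // Hypothetical example: a Job that never launches anything and just
    // logs its callbacks, to show the shape of the three methods.
    class NoopJob extends Job with Logging {
      def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int)
          : Option[TaskDescription] = None  // decline every resource offer

      def statusUpdate(t: TaskStatus) {
        logInfo("NoopJob ignoring status update for TID " + t.getTaskId)
      }

      def error(code: Int, message: String) {
        logError("NoopJob got error %d: %s".format(code, message))
      }
    }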
diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala
index ecb95d7c0bc09fc98d690b7a44d62e375f6c2fba..cb78c2b58273da548b8aa2be17e8fb6a36ae1490 100644
--- a/src/scala/spark/MesosScheduler.scala
+++ b/src/scala/spark/MesosScheduler.scala
@@ -1,13 +1,16 @@
 package spark
 
 import java.io.File
+import java.util.{ArrayList => JArrayList}
+import java.util.{List => JList}
+import java.util.{HashMap => JHashMap}
 
 import scala.collection.mutable.Map
 import scala.collection.mutable.Queue
 import scala.collection.mutable.HashMap
 import scala.collection.JavaConversions._
 
-import mesos.{Scheduler => NScheduler}
+import mesos.{Scheduler => MScheduler}
 import mesos._
 
 // The main Scheduler implementation, which talks to Mesos. Clients are expected
@@ -23,22 +26,20 @@ import mesos._
 // all the offers to the Job and have it load-balance.
 private class MesosScheduler(
   master: String, frameworkName: String, execArg: Array[Byte])
-extends NScheduler with spark.Scheduler with Logging
+extends MScheduler with spark.Scheduler with Logging
 {
-  // Lock used by runTasks to ensure only one thread can be in it
-  val runTasksMutex = new Object()
-
   // Lock used to wait for scheduler to be registered
   var isRegistered = false
   val registeredLock = new Object()
 
-  // Current callback object (may be null)
-  var activeJobsQueue = new Queue[Int]
   var activeJobs = new HashMap[Int, Job]
-  private var nextJobId = 0
+  var activeJobsQueue = new Queue[Job]
+
+  private[spark] var taskIdToJobId = new HashMap[Int, Int]
+
+  private var nextJobId = 0
 
-  def newJobId(): Int = {
+  def newJobId(): Int = this.synchronized {
     val id = nextJobId
     nextJobId += 1
     return id
@@ -60,9 +61,9 @@ extends NScheduler with spark.Scheduler with Logging
     new Thread("Spark scheduler") {
       setDaemon(true)
       override def run {
-        val ns = MesosScheduler.this
-        ns.driver = new MesosSchedulerDriver(ns, master)
-        ns.driver.run()
+        val sched = MesosScheduler.this
+        sched.driver = new MesosSchedulerDriver(sched, master)
+        sched.driver.run()
       }
     }.start
   }
@@ -72,25 +73,26 @@ extends NScheduler with spark.Scheduler with Logging
   override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo =
     new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg)
 
+  /**
+   * The primary means to submit a job to the scheduler. Given a list of tasks,
+   * runs them and returns an array of the results.
+   */
   override def runTasks[T: ClassManifest](tasks: Array[Task[T]]): Array[T] = {
-    var jobId = 0
     waitForRegister()
-    this.synchronized {
-      jobId = newJobId()
-    }
+    val jobId = newJobId()
     val myJob = new SimpleJob(this, tasks, jobId)
     try {
       this.synchronized {
         this.activeJobs(myJob.jobId) = myJob
-        this.activeJobsQueue += myJob.jobId
+        this.activeJobsQueue += myJob
       }
       driver.reviveOffers();
       myJob.join();
     } finally {
       this.synchronized {
         this.activeJobs.remove(myJob.jobId)
-        this.activeJobsQueue.dequeueAll(x => (x == myJob.jobId))
+        this.activeJobsQueue.dequeueAll(x => (x == myJob))
       }
     }
@@ -116,35 +118,34 @@ extends NScheduler with spark.Scheduler with Logging
   }
 
   override def resourceOffer(
-      d: SchedulerDriver, oid: String, offers: java.util.List[SlaveOffer]) {
+      d: SchedulerDriver, oid: String, offers: JList[SlaveOffer]) {
     synchronized {
-      val tasks = new java.util.ArrayList[TaskDescription]
+      val tasks = new JArrayList[TaskDescription]
       val availableCpus = offers.map(_.getParams.get("cpus").toInt)
       val availableMem = offers.map(_.getParams.get("mem").toInt)
-      var launchedTask = true
-      for (jobId <- activeJobsQueue) {
-        launchedTask = true
-        while (launchedTask) {
+      var launchedTask = false
+      for (job <- activeJobsQueue) {
+        do {
           launchedTask = false
           for (i <- 0 until offers.size.toInt) {
            try {
-              activeJobs(jobId).slaveOffer(offers.get(i), availableCpus(i), availableMem(i)) match {
+              job.slaveOffer(offers(i), availableCpus(i), availableMem(i)) match {
                case Some(task) =>
                  tasks.add(task)
                  availableCpus(i) -= task.getParams.get("cpus").toInt
                  availableMem(i) -= task.getParams.get("mem").toInt
-                  launchedTask = launchedTask || true
+                  launchedTask = true
                case None => {}
              }
            } catch {
              case e: Exception => logError("Exception in resourceOffer", e)
            }
          }
-        }
+        } while (launchedTask)
      }
-      val params = new java.util.HashMap[String, String]
+      val params = new JHashMap[String, String]
      params.put("timeout", "1")
-      d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout
+      d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout?
    }
  }
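Reviewer note (illustration only): the key change in the resourceOffer hunk above is that launchedTask now starts out false and the inner loop becomes do/while, so each queued job keeps getting full passes over the offers until a pass launches nothing. A stripped-down model of that loop, with a plain queue of task names standing in for the Mesos types (all names here are hypothetical, not from this patch):

    import scala.collection.mutable.Queue

    // Simplified model of the do/while pass in resourceOffer: keep sweeping
    // the slots for the current job until one sweep launches nothing.
    def assignAll(jobs: Seq[Queue[String]], numSlots: Int): List[String] = {
      var launched = List[String]()
      for (job <- jobs) {
        var launchedTask = false
        do {
          launchedTask = false
          for (i <- 0 until numSlots) {
            if (job.nonEmpty) {        // stands in for slaveOffer returning Some(task)
              launched ::= job.dequeue()
              launchedTask = true
            }
          }
        } while (launchedTask)
      }
      launched.reverse
    }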
@@ -167,8 +168,10 @@ extends NScheduler with spark.Scheduler with Logging
   }
 
   override def error(d: SchedulerDriver, code: Int, message: String) {
+    logError("Mesos error: %s (error code: %d)".format(message, code))
     synchronized {
       if (activeJobs.size > 0) {
+        // Have each job throw a SparkException with the error
         for ((jobId, activeJob) <- activeJobs) {
           try {
             activeJob.error(code, message)
@@ -177,7 +180,9 @@ extends NScheduler with spark.Scheduler with Logging
           }
         }
       } else {
-        logError("Mesos error: %s (error code: %d)".format(message, code))
+        // No jobs are active but we still got an error. Just exit since this
+        // must mean the error is during registration.
+        // It might be good to do something smarter here in the future.
         System.exit(1)
       }
     }
@@ -191,156 +196,3 @@ extends NScheduler with spark.Scheduler with Logging
   // TODO: query Mesos for number of cores
   override def numCores() = System.getProperty("spark.default.parallelism", "2").toInt
 }
-
-
-// Trait representing an object that manages a parallel operation by
-// implementing various scheduler callbacks.
-trait Job {
-  def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int): Option[TaskDescription]
-  def statusUpdate(t: TaskStatus): Unit
-  def error(code: Int, message: String): Unit
-}
-
-
-class SimpleJob[T: ClassManifest](
-  sched: MesosScheduler, tasks: Array[Task[T]], val jobId: Int)
-extends Job with Logging
-{
-  // Maximum time to wait to run a task in a preferred location (in ms)
-  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
-
-  val callingThread = currentThread
-  val numTasks = tasks.length
-  val results = new Array[T](numTasks)
-  val launched = new Array[Boolean](numTasks)
-  val finished = new Array[Boolean](numTasks)
-  val tidToIndex = Map[Int, Int]()
-
-  var allFinished = false
-  val joinLock = new Object()
-
-  var errorHappened = false
-  var errorCode = 0
-  var errorMessage = ""
-
-  var tasksLaunched = 0
-  var tasksFinished = 0
-  var lastPreferredLaunchTime = System.currentTimeMillis
-
-  def setAllFinished() {
-    joinLock.synchronized {
-      allFinished = true
-      joinLock.notifyAll()
-    }
-  }
-
-  def join() {
-    joinLock.synchronized {
-      while (!allFinished)
-        joinLock.wait()
-    }
-  }
-
-  def slaveOffer(offer: SlaveOffer, availableCpus: Int, availableMem: Int): Option[TaskDescription] = {
-    if (tasksLaunched < numTasks) {
-      var checkPrefVals: Array[Boolean] = Array(true)
-      val time = System.currentTimeMillis
-      if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
-        checkPrefVals = Array(true, false) // Allow non-preferred tasks
-      // TODO: Make desiredCpus and desiredMem configurable
-      val desiredCpus = 1
-      val desiredMem = 500
-      if ((availableCpus < desiredCpus) || (availableMem < desiredMem))
-        return None
-      for (checkPref <- checkPrefVals; i <- 0 until numTasks) {
-        if (!launched(i) && (!checkPref ||
-            tasks(i).preferredLocations.contains(offer.getHost) ||
-            tasks(i).preferredLocations.isEmpty))
-        {
-          val taskId = sched.newTaskId()
-          sched.taskIdToJobId(taskId) = jobId
-          tidToIndex(taskId) = i
-          val preferred = if(checkPref) "preferred" else "non-preferred"
-          val message =
-            "Starting task %d as jobId %d, TID %s on slave %s: %s (%s)".format(
-              i, jobId, taskId, offer.getSlaveId, offer.getHost, preferred)
-          logInfo(message)
-          tasks(i).markStarted(offer)
-          launched(i) = true
-          tasksLaunched += 1
-          if (checkPref)
-            lastPreferredLaunchTime = time
-          val params = new java.util.HashMap[String, String]
-          params.put("cpus", "" + desiredCpus)
-          params.put("mem", "" + desiredMem)
-          val serializedTask = Utils.serialize(tasks(i))
-          //logInfo("Serialized size: " + serializedTask.size)
-          return Some(new TaskDescription(taskId, offer.getSlaveId,
-            "task_" + taskId, params, serializedTask))
-        }
-      }
-    }
-    return None
-  }
-
-  def statusUpdate(status: TaskStatus) {
-    status.getState match {
-      case TaskState.TASK_FINISHED =>
-        taskFinished(status)
-      case TaskState.TASK_LOST =>
-        taskLost(status)
-      case TaskState.TASK_FAILED =>
-        taskLost(status)
-      case TaskState.TASK_KILLED =>
-        taskLost(status)
-      case _ =>
-    }
-  }
-
-  def taskFinished(status: TaskStatus) {
-    val tid = status.getTaskId
-    val index = tidToIndex(tid)
-    if (!finished(index)) {
-      tasksFinished += 1
-      logInfo("Finished job %d TID %d (progress: %d/%d)".format(
-        jobId, tid, tasksFinished, numTasks))
-      // Deserialize task result
-      val result = Utils.deserialize[TaskResult[T]](status.getData)
-      results(index) = result.value
-      // Update accumulators
-      Accumulators.add(callingThread, result.accumUpdates)
-      // Mark finished and stop if we've finished all the tasks
-      finished(index) = true
-      // Remove TID -> jobId mapping from sched
-      sched.taskIdToJobId.remove(tid)
-      if (tasksFinished == numTasks)
-        setAllFinished()
-    } else {
-      logInfo("Ignoring task-finished event for TID " + tid +
-        " because task " + index + " is already finished")
-    }
-  }
-
-  def taskLost(status: TaskStatus) {
-    val tid = status.getTaskId
-    val index = tidToIndex(tid)
-    if (!finished(index)) {
-      logInfo("Lost job " + jobId + " TID " + tid)
-      launched(index) = false
-      sched.taskIdToJobId.remove(tid)
-      tasksLaunched -= 1
-    } else {
-      logInfo("Ignoring task-lost event for TID " + tid +
-        " because task " + index + " is already finished")
-    }
-  }
-
-  def error(code: Int, message: String) {
-    // Save the error message
-    errorHappened = true
-    errorCode = code
-    errorMessage = message
-    // Indicate to caller thread that we're done
-    setAllFinished()
-  }
-}
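Reviewer note (generic idiom, not code from this patch): the Job trait and SimpleJob removed above reappear nearly unchanged in the two new files. Both waitForRegister() in the scheduler and join()/setAllFinished() in SimpleJob below rely on the same flag-plus-wait/notifyAll pattern; a self-contained sketch of that idiom, with a made-up Latch name:

    // Generic sketch of the idiom behind registeredLock/isRegistered and
    // joinLock/allFinished: a boolean flag guarded by a lock.
    class Latch {
      private var done = false
      private val lock = new Object()

      def await(): Unit = lock.synchronized {
        while (!done)      // loop guards against spurious wakeups
          lock.wait()
      }

      def release(): Unit = lock.synchronized {
        done = true
        lock.notifyAll()
      }
    }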
+          params.put("mem", "" + memPerTask)
+          val serializedTask = Utils.serialize(tasks(i))
+          logDebug("Serialized size: " + serializedTask.size)
+          return Some(new TaskDescription(taskId, offer.getSlaveId,
+            "task_" + taskId, params, serializedTask))
+        }
+      }
+    }
+    return None
+  }
+
+  def statusUpdate(status: TaskStatus) {
+    status.getState match {
+      case TaskState.TASK_FINISHED =>
+        taskFinished(status)
+      case TaskState.TASK_LOST =>
+        taskLost(status)
+      case TaskState.TASK_FAILED =>
+        taskLost(status)
+      case TaskState.TASK_KILLED =>
+        taskLost(status)
+      case _ =>
+    }
+  }
+
+  def taskFinished(status: TaskStatus) {
+    val tid = status.getTaskId
+    val index = tidToIndex(tid)
+    if (!finished(index)) {
+      tasksFinished += 1
+      logInfo("Finished TID %d (progress: %d/%d)".format(
+        tid, tasksFinished, numTasks))
+      // Deserialize task result
+      val result = Utils.deserialize[TaskResult[T]](status.getData)
+      results(index) = result.value
+      // Update accumulators
+      Accumulators.add(callingThread, result.accumUpdates)
+      // Mark finished and stop if we've finished all the tasks
+      finished(index) = true
+      // Remove TID -> jobId mapping from sched
+      sched.taskIdToJobId.remove(tid)
+      if (tasksFinished == numTasks)
+        setAllFinished()
+    } else {
+      logInfo("Ignoring task-finished event for TID " + tid +
+        " because task " + index + " is already finished")
+    }
+  }
+
+  def taskLost(status: TaskStatus) {
+    val tid = status.getTaskId
+    val index = tidToIndex(tid)
+    if (!finished(index)) {
+      logInfo("Lost TID %d (task %d:%d)".format(tid, jobId, index))
+      launched(index) = false
+      sched.taskIdToJobId.remove(tid)
+      tasksLaunched -= 1
+    } else {
+      logInfo("Ignoring task-lost event for TID " + tid +
+        " because task " + index + " is already finished")
+    }
+  }
+
+  def error(code: Int, message: String) {
+    // Save the error message
+    errorHappened = true
+    errorCode = code
+    errorMessage = message
+    // Indicate to caller thread that we're done
+    setAllFinished()
+  }
+}
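Reviewer note (hypothetical sketch): slaveOffer above implements a simple locality wait: only hosts in a task's preferred set are considered until LOCALITY_WAIT ms have passed since the last preferred launch, after which any host is allowed. The policy restated in isolation, with made-up names not taken from this patch:

    // Standalone restatement of the locality test in slaveOffer; the
    // function and parameter names are illustrative only.
    def hostAllowed(preferred: Seq[String], host: String,
                    lastPreferredLaunchTime: Long, localityWait: Long,
                    now: Long): Boolean = {
      val waitedLongEnough = now - lastPreferredLaunchTime > localityWait
      preferred.isEmpty || preferred.contains(host) || waitedLongEnough
    }

Note that the real code still makes a preferred pass first even after the wait has expired (checkPrefVals = Array(true, false)), so local tasks keep priority over non-local ones.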