diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala
index 7f82c4934836796102cf291fb2a7946d00efdba9..c8943e3d3a9d6c8185c2aa690718ba90534d33d4 100644
--- a/src/scala/spark/MesosScheduler.scala
+++ b/src/scala/spark/MesosScheduler.scala
@@ -3,6 +3,7 @@ package spark
 import java.io.File
 
 import scala.collection.mutable.Map
+import scala.collection.mutable.HashMap
 import scala.collection.JavaConversions._
 
 import mesos.{Scheduler => NScheduler}
@@ -31,7 +32,15 @@ extends NScheduler with spark.Scheduler
   val registeredLock = new Object()
 
   // Current callback object (may be null)
-  var activeOp: ParallelOperation = null
+  var activeOps = new HashMap[Int, ParallelOperation]
+  private var nextOpId = 0
+  private[spark] var taskIdToOpId = new HashMap[Int, Int]
+
+  def newOpId(): Int = {
+    val id = nextOpId
+    nextOpId += 1
+    return id
+  }
 
   // Incrementing task ID
   private var nextTaskId = 0
@@ -62,27 +71,29 @@ extends NScheduler with spark.Scheduler
     new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg)
 
   override def runTasks[T: ClassManifest](tasks: Array[Task[T]]): Array[T] = {
+    var opId = 0
     runTasksMutex.synchronized {
       waitForRegister()
-      val myOp = new SimpleParallelOperation(this, tasks)
+      opId = newOpId()
+    }
+    val myOp = new SimpleParallelOperation(this, tasks, opId)
 
-      try {
-        this.synchronized {
-          this.activeOp = myOp
-        }
-        driver.reviveOffers();
-        myOp.join();
-      } finally {
-        this.synchronized {
-          this.activeOp = null
-        }
+    try {
+      this.synchronized {
+        this.activeOps(myOp.opId) = myOp
+      }
+      driver.reviveOffers();
+      myOp.join();
+    } finally {
+      this.synchronized {
+        this.activeOps.remove(myOp.opId)
       }
-
-      if (myOp.errorHappened)
-        throw new SparkException(myOp.errorMessage, myOp.errorCode)
-      else
-        return myOp.results
     }
+
+    if (myOp.errorHappened)
+      throw new SparkException(myOp.errorMessage, myOp.errorCode)
+    else
+      return myOp.results
   }
 
   override def registered(d: SchedulerDriver, frameworkId: String) {
@@ -104,28 +115,26 @@ extends NScheduler with spark.Scheduler
       d: SchedulerDriver, oid: String, offers: java.util.List[SlaveOffer]) {
     synchronized {
       val tasks = new java.util.ArrayList[TaskDescription]
-      if (activeOp != null) {
-        try {
-          val availableCpus = offers.map(_.getParams.get("cpus").toInt)
-          val availableMem = offers.map(_.getParams.get("mem").toInt)
-          var resourcesAvailable = true
-          while (resourcesAvailable) {
-            resourcesAvailable = false
-            for (i <- 0 until offers.size.toInt) {
-              activeOp.slaveOffer(offers.get(i), availableCpus(i), availableMem(i)) match {
-                case Some(task) =>
-                  tasks.add(task)
-                  availableCpus(i) -= task.getParams.get("cpus").toInt
-                  availableMem(i) -= task.getParams.get("mem").toInt
-                  resourcesAvailable = resourcesAvailable || true
-                case None => {}
-              }
-            }
+      val availableCpus = offers.map(_.getParams.get("cpus").toInt)
+      val availableMem = offers.map(_.getParams.get("mem").toInt)
+      var resourcesAvailable = true
+      while (resourcesAvailable) {
+        resourcesAvailable = false
+        for (i <- 0 until offers.size.toInt; (opId, activeOp) <- activeOps) {
+          try {
+            activeOp.slaveOffer(offers.get(i), availableCpus(i), availableMem(i)) match {
+              case Some(task) =>
+                tasks.add(task)
+                availableCpus(i) -= task.getParams.get("cpus").toInt
+                availableMem(i) -= task.getParams.get("mem").toInt
+                resourcesAvailable = resourcesAvailable || true
+              case None => {}
+            }
+          } catch {
+            case e: Exception => e.printStackTrace
          }
-        } catch {
-          case e: Exception => e.printStackTrace
         }
-      }
+      }
       val params = new java.util.HashMap[String, String]
       params.put("timeout", "1")
       d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout
@@ -135,9 +144,15 @@ extends NScheduler with spark.Scheduler
   override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
     synchronized {
       try {
-        if (activeOp != null) {
-          activeOp.statusUpdate(status)
+        taskIdToOpId.get(status.getTaskId) match {
+          case Some(opId) =>
+            if (activeOps.contains(opId)) {
+              activeOps(opId).statusUpdate(status)
+            }
+          case None =>
+            println("TID " + status.getTaskId + " already finished")
         }
+
       } catch {
         case e: Exception => e.printStackTrace
       }
@@ -146,11 +161,13 @@ extends NScheduler with spark.Scheduler
 
   override def error(d: SchedulerDriver, code: Int, message: String) {
     synchronized {
-      if (activeOp != null) {
-        try {
-          activeOp.error(code, message)
-        } catch {
-          case e: Exception => e.printStackTrace
+      if (activeOps.size > 0) {
+        for ((opId, activeOp) <- activeOps) {
+          try {
+            activeOp.error(code, message)
+          } catch {
+            case e: Exception => e.printStackTrace
+          }
         }
       } else {
         val msg = "Mesos error: %s (error code: %d)".format(message, code)
@@ -180,7 +197,7 @@ trait ParallelOperation
 
 
 class SimpleParallelOperation[T: ClassManifest](
-    sched: MesosScheduler, tasks: Array[Task[T]])
+    sched: MesosScheduler, tasks: Array[Task[T]], val opId: Int)
 extends ParallelOperation
 {
   // Maximum time to wait to run a task in a preferred location (in ms)
@@ -235,10 +252,10 @@ extends ParallelOperation
             tasks(i).preferredLocations.isEmpty)) {
           val taskId = sched.newTaskId()
+          sched.taskIdToOpId(taskId) = opId
           tidToIndex(taskId) = i
-          //printf("Starting task %d as TID %s on slave %s: %s (%s)\n",
-          printf("Starting task %d as TID %s on slave %s: %s (%s)",
-            i, taskId, offer.getSlaveId, offer.getHost,
+          printf("Starting task %d as opId %d, TID %s on slave %s: %s (%s)\n",
+            i, opId, taskId, offer.getSlaveId, offer.getHost,
             if(checkPref) "preferred" else "non-preferred")
           tasks(i).markStarted(offer)
           launched(i) = true
 
@@ -274,7 +291,7 @@ extends ParallelOperation
 
   def taskFinished(status: TaskStatus) {
     val tid = status.getTaskId
-    print("Finished TID " + tid)
+    print("Finished opId " + opId + " TID " + tid)
     if (!finished(tidToIndex(tid))) {
       // Deserialize task result
       val result = Utils.deserialize[TaskResult[T]](status.getData)
@@ -283,6 +300,8 @@ extends ParallelOperation
       Accumulators.add(callingThread, result.accumUpdates)
       // Mark finished and stop if we've finished all the tasks
       finished(tidToIndex(tid)) = true
+      // Remove TID -> opId mapping from sched
+      sched.taskIdToOpId.remove(tid)
       tasksFinished += 1
 
       println(", finished " + tasksFinished + "/" + numTasks)
@@ -295,7 +314,7 @@ extends ParallelOperation
 
   def taskLost(status: TaskStatus) {
     val tid = status.getTaskId
-    println("Lost TID " + tid)
+    println("Lost opId " + opId + " TID " + tid)
     if (!finished(tidToIndex(tid))) {
       launched(tidToIndex(tid)) = false
       tasksLaunched -= 1
diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala
index 1188367bdd00a00446e1c468769f0d4f787c267f..35d3458723f7e8572afbbd189a980f5dc73a0c40 100644
--- a/src/scala/spark/SparkContext.scala
+++ b/src/scala/spark/SparkContext.scala
@@ -4,6 +4,17 @@ import java.io._
 import java.util.UUID
 
 import scala.collection.mutable.ArrayBuffer
+import scala.actors.Actor._
+
+case class SparkAsyncLock(var finished: Boolean = false) {
+  def join() {
+    this.synchronized {
+      while (!finished) {
+        this.wait
+      }
+    }
+  }
+}
 
 class SparkContext(master: String, frameworkName: String) {
   Broadcast.initialize(true)
@@ -21,6 +32,18 @@ class SparkContext(master: String, frameworkName: String) {
   def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, local)
   //def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, local)
 
+  def fork(f: => Unit): SparkAsyncLock = {
+    val thisLock = new SparkAsyncLock
+    actor {
+      f
+      thisLock.synchronized {
+        thisLock.finished = true
+        thisLock.notifyAll()
+      }
+    }
+    thisLock
+  }
+
   def textFile(path: String) = new HdfsTextFile(this, path)
 
   val LOCAL_REGEX = """local\[([0-9]+)\]""".r
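
Usage sketch (not part of the patch above): fork returns its SparkAsyncLock immediately, the forked block runs inside a scala.actors actor so a job it submits goes through runTasks under a fresh opId, and join() blocks until the actor sets finished and calls notifyAll(). The master string, input paths, and the count() action below are illustrative assumptions, not APIs confirmed by this diff.

    val sc = new SparkContext("local[2]", "multiJobTest")

    // Each forked block runs in its own actor, so the scheduler can hold
    // two entries in activeOps with distinct opIds at the same time.
    val lockA = sc.fork {
      println("a: " + sc.textFile("/tmp/a.txt").count())  // count() assumed
    }
    val lockB = sc.fork {
      println("b: " + sc.textFile("/tmp/b.txt").count())  // count() assumed
    }

    // Wait for both concurrent jobs to finish before exiting.
    lockA.join()
    lockB.join()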