diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala index c8943e3d3a9d6c8185c2aa690718ba90534d33d4..984a5e56377effdd4b9e38bcf6a6475e00894555 100644 --- a/src/scala/spark/MesosScheduler.scala +++ b/src/scala/spark/MesosScheduler.scala @@ -3,6 +3,7 @@ package spark import java.io.File import scala.collection.mutable.Map +import scala.collection.mutable.Queue import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ @@ -32,6 +33,7 @@ extends NScheduler with spark.Scheduler val registeredLock = new Object() // Current callback object (may be null) + var activeOpsQueue = new Queue[Int] var activeOps = new HashMap[Int, ParallelOperation] private var nextOpId = 0 private[spark] var taskIdToOpId = new HashMap[Int, Int] @@ -72,8 +74,8 @@ extends NScheduler with spark.Scheduler override def runTasks[T: ClassManifest](tasks: Array[Task[T]]): Array[T] = { var opId = 0 - runTasksMutex.synchronized { - waitForRegister() + waitForRegister() + this.synchronized { opId = newOpId() } val myOp = new SimpleParallelOperation(this, tasks, opId) @@ -81,12 +83,14 @@ extends NScheduler with spark.Scheduler try { this.synchronized { this.activeOps(myOp.opId) = myOp + this.activeOpsQueue += myOp.opId } driver.reviveOffers(); myOp.join(); } finally { this.synchronized { this.activeOps.remove(myOp.opId) + this.activeOpsQueue.dequeueAll(x => (x == myOp.opId)) } } @@ -117,21 +121,24 @@ extends NScheduler with spark.Scheduler val tasks = new java.util.ArrayList[TaskDescription] val availableCpus = offers.map(_.getParams.get("cpus").toInt) val availableMem = offers.map(_.getParams.get("mem").toInt) - var resourcesAvailable = true - while (resourcesAvailable) { - resourcesAvailable = false - for (i <- 0 until offers.size.toInt; (opId, activeOp) <- activeOps) { - try { - activeOp.slaveOffer(offers.get(i), availableCpus(i), availableMem(i)) match { - case Some(task) => - tasks.add(task) - availableCpus(i) -= task.getParams.get("cpus").toInt - availableMem(i) -= task.getParams.get("mem").toInt - resourcesAvailable = resourcesAvailable || true - case None => {} + var launchedTask = true + for (opId <- activeOpsQueue) { + launchedTask = true + while (launchedTask) { + launchedTask = false + for (i <- 0 until offers.size.toInt) { + try { + activeOps(opId).slaveOffer(offers.get(i), availableCpus(i), availableMem(i)) match { + case Some(task) => + tasks.add(task) + availableCpus(i) -= task.getParams.get("cpus").toInt + availableMem(i) -= task.getParams.get("mem").toInt + launchedTask = launchedTask || true + case None => {} + } + } catch { + case e: Exception => e.printStackTrace } - } catch { - case e: Exception => e.printStackTrace } } } @@ -317,6 +324,7 @@ extends ParallelOperation println("Lost opId " + opId + " TID " + tid) if (!finished(tidToIndex(tid))) { launched(tidToIndex(tid)) = false + sched.taskIdToOpId.remove(tid) tasksLaunched -= 1 } else { printf("Task %s had already finished, so ignoring it\n", tidToIndex(tid)) diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala index 35d3458723f7e8572afbbd189a980f5dc73a0c40..d5d4db4678eb4df00df59350286eea879d52e276 100644 --- a/src/scala/spark/SparkContext.scala +++ b/src/scala/spark/SparkContext.scala @@ -6,16 +6,6 @@ import java.util.UUID import scala.collection.mutable.ArrayBuffer import scala.actors.Actor._ -case class SparkAsyncLock(var finished: Boolean = false) { - def join() { - this.synchronized { - while (!finished) { - this.wait - } - } - } -} - class SparkContext(master: String, frameworkName: String) { Broadcast.initialize(true) @@ -32,18 +22,6 @@ class SparkContext(master: String, frameworkName: String) { def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, local) //def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, local) - def fork(f: => Unit): SparkAsyncLock = { - val thisLock = new SparkAsyncLock - actor { - f - thisLock.synchronized { - thisLock.finished = true - thisLock.notifyAll() - } - } - thisLock - } - def textFile(path: String) = new HdfsTextFile(this, path) val LOCAL_REGEX = """local\[([0-9]+)\]""".r