diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala index 5adff032eb1f5f8bba719fd9a568a4e4f946d164..470be69e504657b159cf44a6d2fee665bfd7193d 100644 --- a/src/scala/spark/MesosScheduler.scala +++ b/src/scala/spark/MesosScheduler.scala @@ -5,9 +5,10 @@ import java.util.{ArrayList => JArrayList} import java.util.{List => JList} import java.util.{HashMap => JHashMap} +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet import scala.collection.mutable.Map import scala.collection.mutable.Queue -import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ import mesos.{Scheduler => MScheduler} @@ -36,6 +37,7 @@ extends MScheduler with spark.Scheduler with Logging private var activeJobsQueue = new Queue[Job] private var taskIdToJobId = new HashMap[Int, Int] + private var jobTasks = new HashMap[Int, HashSet[Int]] private var nextJobId = 0 @@ -95,18 +97,20 @@ extends MScheduler with spark.Scheduler with Logging waitForRegister() val jobId = newJobId() val myJob = new SimpleJob(this, tasks, jobId) - try { this.synchronized { - this.activeJobs(myJob.jobId) = myJob - this.activeJobsQueue += myJob + activeJobs(jobId) = myJob + activeJobsQueue += myJob + jobTasks(jobId) = new HashSet() } driver.reviveOffers(); return myJob.join(); } finally { this.synchronized { - this.activeJobs.remove(myJob.jobId) - this.activeJobsQueue.dequeueAll(x => (x == myJob)) + activeJobs -= jobId + activeJobsQueue.dequeueAll(x => (x == myJob)) + taskIdToJobId --= jobTasks(jobId) + jobTasks.remove(jobId) } } } @@ -147,6 +151,7 @@ extends MScheduler with spark.Scheduler with Logging case Some(task) => tasks.add(task) taskIdToJobId(task.getTaskId) = job.getId + jobTasks(job.getId) += task.getTaskId availableCpus(i) -= task.getParams.get("cpus").toInt availableMem(i) -= task.getParams.get("mem").toInt launchedTask = true @@ -182,6 +187,7 @@ extends MScheduler with spark.Scheduler with Logging } if (isFinished(status.getState)) { taskIdToJobId.remove(status.getTaskId) + jobTasks(jobId) -= status.getTaskId } case None => logInfo("TID " + status.getTaskId + " already finished")