diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala
index 5adff032eb1f5f8bba719fd9a568a4e4f946d164..470be69e504657b159cf44a6d2fee665bfd7193d 100644
--- a/src/scala/spark/MesosScheduler.scala
+++ b/src/scala/spark/MesosScheduler.scala
@@ -5,9 +5,10 @@ import java.util.{ArrayList => JArrayList}
 import java.util.{List => JList}
 import java.util.{HashMap => JHashMap}
 
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
 import scala.collection.mutable.Map
 import scala.collection.mutable.Queue
-import scala.collection.mutable.HashMap
 import scala.collection.JavaConversions._
 
 import mesos.{Scheduler => MScheduler}
@@ -36,6 +37,7 @@ extends MScheduler with spark.Scheduler with Logging
   private var activeJobsQueue = new Queue[Job]
 
   private var taskIdToJobId = new HashMap[Int, Int]
+  private var jobTasks = new HashMap[Int, HashSet[Int]]
 
   private var nextJobId = 0
   
@@ -95,18 +97,20 @@ extends MScheduler with spark.Scheduler with Logging
     waitForRegister()
     val jobId = newJobId()
     val myJob = new SimpleJob(this, tasks, jobId)
-
     try {
       this.synchronized {
-        this.activeJobs(myJob.jobId) = myJob
-        this.activeJobsQueue += myJob
+        activeJobs(jobId) = myJob
+        activeJobsQueue += myJob
+        jobTasks(jobId) = new HashSet()
       }
       driver.reviveOffers();
       return myJob.join();
     } finally {
       this.synchronized {
-        this.activeJobs.remove(myJob.jobId)
-        this.activeJobsQueue.dequeueAll(x => (x == myJob))
+        activeJobs -= jobId
+        activeJobsQueue.dequeueAll(x => (x == myJob))
+        taskIdToJobId --= jobTasks(jobId)
+        jobTasks.remove(jobId)
       }
     }
   }
@@ -147,6 +151,7 @@ extends MScheduler with spark.Scheduler with Logging
                 case Some(task) =>
                   tasks.add(task)
                   taskIdToJobId(task.getTaskId) = job.getId
+                  jobTasks(job.getId) += task.getTaskId
                   availableCpus(i) -= task.getParams.get("cpus").toInt
                   availableMem(i) -= task.getParams.get("mem").toInt
                   launchedTask = true
@@ -182,6 +187,7 @@ extends MScheduler with spark.Scheduler with Logging
             }
             if (isFinished(status.getState)) {
               taskIdToJobId.remove(status.getTaskId)
+              jobTasks(jobId) -= status.getTaskId
             }
           case None =>
             logInfo("TID " + status.getTaskId + " already finished")