diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 3e3f1ad031e663deaa3bddf0624b2e69d0f806e8..67446da0a8b8d0a45fb4e1ff68505981004ba95a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -93,10 +93,12 @@ private[spark] class TaskSchedulerImpl(
   // Incrementing task IDs
   val nextTaskId = new AtomicLong(0)
 
-  // Number of tasks running on each executor
-  private val executorIdToTaskCount = new HashMap[String, Int]
+  // IDs of the tasks running on each executor
+  private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]
 
-  def runningTasksByExecutors(): Map[String, Int] = executorIdToTaskCount.toMap
+  def runningTasksByExecutors(): Map[String, Int] = {
+    executorIdToRunningTaskIds.toMap.mapValues(_.size)
+  }
 
   // The set of executors we have on each host; this is used to compute hostsAlive, which
   // in turn is used to decide when we can attain data locality on a given host
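
Aside on the new runningTasksByExecutors(): in Scala 2.x, Map.mapValues returns
a lazy view rather than a strict map, and toMap copies only the outer map while
the inner HashSet values stay shared, mutable references. A minimal standalone
sketch of that behavior (names here are illustrative, not part of the patch):

    import scala.collection.mutable

    object MapValuesViewSketch {
      def main(args: Array[String]): Unit = {
        val running = new mutable.HashMap[String, mutable.HashSet[Long]]
        running("exec-0") = mutable.HashSet(1L, 2L)

        // toMap copies the outer map, but the inner sets are shared references,
        // and mapValues yields a view that recomputes _.size on every lookup.
        val counts = running.toMap.mapValues(_.size)
        println(counts("exec-0")) // 2

        running("exec-0") += 3L   // mutate the shared inner set
        println(counts("exec-0")) // 3 -- the view reflects the mutation

        // A strict copy pins the counts at call time:
        val strict = running.toMap.map { case (k, v) => k -> v.size }
        running("exec-0") += 4L
        println(strict("exec-0")) // 3
      }
    }

Callers that need a stable snapshot would want the strict form; the view form
is fine when the result is consumed immediately.
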
@@ -264,7 +266,7 @@ private[spark] class TaskSchedulerImpl(
             val tid = task.taskId
             taskIdToTaskSetManager(tid) = taskSet
             taskIdToExecutorId(tid) = execId
-            executorIdToTaskCount(execId) += 1
+            executorIdToRunningTaskIds(execId).add(tid)
             availableCpus(i) -= CPUS_PER_TASK
             assert(availableCpus(i) >= 0)
             launchedTask = true
@@ -294,11 +296,11 @@ private[spark] class TaskSchedulerImpl(
       if (!hostToExecutors.contains(o.host)) {
         hostToExecutors(o.host) = new HashSet[String]()
       }
-      if (!executorIdToTaskCount.contains(o.executorId)) {
+      if (!executorIdToRunningTaskIds.contains(o.executorId)) {
         hostToExecutors(o.host) += o.executorId
         executorAdded(o.executorId, o.host)
         executorIdToHost(o.executorId) = o.host
-        executorIdToTaskCount(o.executorId) = 0
+        executorIdToRunningTaskIds(o.executorId) = HashSet[Long]()
         newExecAvail = true
       }
       for (rack <- getRackForHost(o.host)) {
@@ -349,38 +351,34 @@ private[spark] class TaskSchedulerImpl(
     var reason: Option[ExecutorLossReason] = None
     synchronized {
       try {
-        if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
-          // We lost this entire executor, so remember that it's gone
-          val execId = taskIdToExecutorId(tid)
-
-          if (executorIdToTaskCount.contains(execId)) {
-            reason = Some(
-              SlaveLost(s"Task $tid was lost, so marking the executor as lost as well."))
-            removeExecutor(execId, reason.get)
-            failedExecutor = Some(execId)
-          }
-        }
         taskIdToTaskSetManager.get(tid) match {
           case Some(taskSet) =>
-            if (TaskState.isFinished(state)) {
-              taskIdToTaskSetManager.remove(tid)
-              taskIdToExecutorId.remove(tid).foreach { execId =>
-                if (executorIdToTaskCount.contains(execId)) {
-                  executorIdToTaskCount(execId) -= 1
-                }
+            if (state == TaskState.LOST) {
+              // TaskState.LOST is only used by the deprecated Mesos fine-grained scheduling mode,
+              // where each executor corresponds to a single task, so mark the executor as failed.
+              val execId = taskIdToExecutorId.getOrElse(tid, throw new IllegalStateException(
+                "taskIdToTaskSetManager.contains(tid) <=> taskIdToExecutorId.contains(tid)"))
+              if (executorIdToRunningTaskIds.contains(execId)) {
+                reason = Some(
+                  SlaveLost(s"Task $tid was lost, so marking the executor as lost as well."))
+                removeExecutor(execId, reason.get)
+                failedExecutor = Some(execId)
               }
             }
-            if (state == TaskState.FINISHED) {
-              taskSet.removeRunningTask(tid)
-              taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
-            } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+            if (TaskState.isFinished(state)) {
+              cleanupTaskState(tid)
               taskSet.removeRunningTask(tid)
-              taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
+              if (state == TaskState.FINISHED) {
+                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
+              }
             }
           case None =>
             logError(
               ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
-                "likely the result of receiving duplicate task finished status updates)")
+                "likely the result of receiving duplicate task finished status updates) or its " +
+                "executor has been marked as failed.")
                 .format(state, tid))
         }
       } catch {
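
The getOrElse(tid, throw new IllegalStateException(...)) idiom above typechecks
because a throw expression has type Nothing, which conforms to any expected
type, and getOrElse takes its default by name, so the exception is raised only
when the key is genuinely absent. A quick sketch with illustrative names:

    // getOrElse's default is a by-name parameter (default: => V1), so the
    // throw below is evaluated only if the key is missing; Nothing <: Long.
    val ids = Map("a" -> 1L)
    val ok: Long = ids.getOrElse("a", throw new IllegalStateException("unreachable"))
    // ids.getOrElse("b", throw new IllegalStateException("missing key"))  // would throw

Here the exception message documents the invariant being asserted: a task known
to taskIdToTaskSetManager must also be known to taskIdToExecutorId.
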
@@ -491,7 +489,7 @@ private[spark] class TaskSchedulerImpl(
     var failedExecutor: Option[String] = None
 
     synchronized {
-      if (executorIdToTaskCount.contains(executorId)) {
+      if (executorIdToRunningTaskIds.contains(executorId)) {
         val hostPort = executorIdToHost(executorId)
         logExecutorLoss(executorId, hostPort, reason)
         removeExecutor(executorId, reason)
@@ -533,13 +531,31 @@ private[spark] class TaskSchedulerImpl(
       logError(s"Lost executor $executorId on $hostPort: $reason")
   }
 
+  /**
+   * Cleans up the TaskScheduler's state for tracking the given task.
+   */
+  private def cleanupTaskState(tid: Long): Unit = {
+    taskIdToTaskSetManager.remove(tid)
+    taskIdToExecutorId.remove(tid).foreach { executorId =>
+      executorIdToRunningTaskIds.get(executorId).foreach { _.remove(tid) }
+    }
+  }
+
   /**
    * Remove an executor from all our data structures and mark it as lost. If the executor's loss
    * reason is not yet known, do not yet remove its association with its host nor update the status
    * of any running tasks, since the loss reason defines whether we'll fail those tasks.
    */
   private def removeExecutor(executorId: String, reason: ExecutorLossReason) {
-    executorIdToTaskCount -= executorId
+    // The tasks on the lost executor may not send any more status updates (because the executor
+    // has been lost), so they should be cleaned up here.
+    executorIdToRunningTaskIds.remove(executorId).foreach { taskIds =>
+      logDebug("Cleaning up TaskScheduler state for tasks " +
+        s"${taskIds.mkString("[", ",", "]")} on failed executor $executorId")
+      // We do not notify the TaskSetManager of the task failures because that will
+      // happen below in the rootPool.executorLost() call.
+      taskIds.foreach(cleanupTaskState)
+    }
 
     val host = executorIdToHost(executorId)
     val execs = hostToExecutors.getOrElse(host, new HashSet)
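
Both cleanupTaskState above and this block lean on mutable.Map.remove returning
Option[V]: remove(...).foreach deletes the entry and runs the follow-up cleanup
only when the key was actually present. In sketch form (illustrative data, not
from the patch):

    import scala.collection.mutable

    val tasksByExecutor = mutable.HashMap("exec-0" -> mutable.HashSet(1L, 2L))

    // Key present: remove returns Some(taskIds) and the cleanup body runs.
    tasksByExecutor.remove("exec-0").foreach { taskIds =>
      taskIds.foreach(tid => println(s"cleaning up task $tid"))
    }

    // Key already gone: remove returns None and the body is skipped.
    tasksByExecutor.remove("exec-0").foreach { _ => println("not reached") }
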
@@ -577,11 +593,11 @@ private[spark] class TaskSchedulerImpl(
   }
 
   def isExecutorAlive(execId: String): Boolean = synchronized {
-    executorIdToTaskCount.contains(execId)
+    executorIdToRunningTaskIds.contains(execId)
   }
 
   def isExecutorBusy(execId: String): Boolean = synchronized {
-    executorIdToTaskCount.getOrElse(execId, -1) > 0
+    executorIdToRunningTaskIds.get(execId).exists(_.nonEmpty)
   }
 
   // By default, rack is unknown
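
The rewritten isExecutorBusy relies on Option.exists, which is false for None,
so an unknown executor reports not-busy exactly as the old
getOrElse(execId, -1) > 0 did. A quick equivalence check (illustrative names):

    import scala.collection.mutable

    val running = mutable.HashMap(
      "exec-0" -> mutable.HashSet(42L),          // one running task
      "exec-1" -> mutable.HashSet.empty[Long])   // registered but idle

    assert(running.get("exec-0").exists(_.nonEmpty))    // busy
    assert(!running.get("exec-1").exists(_.nonEmpty))   // known, not busy
    assert(!running.get("exec-404").exists(_.nonEmpty)) // unknown: None.exists == false
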
diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
index e29eb8552e134cc448707e49d764a4dcfabc667f..05dad7a4b86adb7608d39df0517c482e999b7f9d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
@@ -433,10 +433,11 @@ class StandaloneDynamicAllocationSuite
     assert(executors.size === 2)
 
     // simulate running a task on the executor
-    val getMap = PrivateMethod[mutable.HashMap[String, Int]]('executorIdToTaskCount)
+    val getMap =
+      PrivateMethod[mutable.HashMap[String, mutable.HashSet[Long]]]('executorIdToRunningTaskIds)
     val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
-    val executorIdToTaskCount = taskScheduler invokePrivate getMap()
-    executorIdToTaskCount(executors.head) = 1
+    val executorIdToRunningTaskIds = taskScheduler invokePrivate getMap()
+    executorIdToRunningTaskIds(executors.head) = mutable.HashSet(1L)
     // kill the busy executor without force; this should fail
     assert(killExecutor(sc, executors.head, force = false).isEmpty)
     apps = getApplications()
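
For readers unfamiliar with ScalaTest's PrivateMethodTester used here:
PrivateMethod is named by a Symbol and resolved reflectively at runtime, which
is why renaming a private member forces the test to update its symbol, as this
hunk does. A minimal sketch (Counter and bump are made-up names):

    import org.scalatest.PrivateMethodTester

    object PrivateMethodSketch extends PrivateMethodTester {
      class Counter { private def bump(n: Int): Int = n + 1 }

      def main(args: Array[String]): Unit = {
        // The Symbol, not the compiler, ties this to Counter.bump, so a rename
        // of bump surfaces only as a runtime failure here.
        val bump = PrivateMethod[Int]('bump)
        println(new Counter invokePrivate bump(41)) // prints 42
      }
    }
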
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 5dc7708530e2721f79a9ba50c5be0049399a5ce7..59bea27596c793582ca317ef3498e18662f9de76 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.scheduler
 
+import java.nio.ByteBuffer
+
 import scala.collection.mutable.HashMap
 
 import org.mockito.Matchers.{anyInt, anyString, eq => meq}
@@ -648,4 +650,69 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     assert(taskScheduler.getExecutorsAliveOnHost("host1") === Some(Set("executor1", "executor3")))
   }
 
+  test("if an executor is lost then the state for its running tasks is cleaned up (SPARK-18553)") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+
+    val e0Offers = IndexedSeq(WorkerOffer("executor0", "host0", 1))
+    val attempt1 = FakeTask.createTaskSet(1)
+
+    // submit attempt 1, offer resources, task gets scheduled
+    taskScheduler.submitTasks(attempt1)
+    val taskDescriptions = taskScheduler.resourceOffers(e0Offers).flatten
+    assert(1 === taskDescriptions.length)
+
+    // mark executor0 as dead
+    taskScheduler.executorLost("executor0", SlaveLost())
+    assert(!taskScheduler.isExecutorAlive("executor0"))
+    assert(!taskScheduler.hasExecutorsAliveOnHost("host0"))
+    assert(taskScheduler.getExecutorsAliveOnHost("host0").isEmpty)
+
+    // Check that state associated with the lost task attempt is cleaned up:
+    assert(taskScheduler.taskIdToExecutorId.isEmpty)
+    assert(taskScheduler.taskIdToTaskSetManager.isEmpty)
+    assert(taskScheduler.runningTasksByExecutors().get("executor0").isEmpty)
+  }
+
+  test("if a task finishes with TaskState.LOST its executor is marked as dead") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+
+    val e0Offers = IndexedSeq(WorkerOffer("executor0", "host0", 1))
+    val attempt1 = FakeTask.createTaskSet(1)
+
+    // submit attempt 1, offer resources, task gets scheduled
+    taskScheduler.submitTasks(attempt1)
+    val taskDescriptions = taskScheduler.resourceOffers(e0Offers).flatten
+    assert(1 === taskDescriptions.length)
+
+    // Report the task as failed with TaskState.LOST
+    taskScheduler.statusUpdate(
+      tid = taskDescriptions.head.taskId,
+      state = TaskState.LOST,
+      serializedData = ByteBuffer.allocate(0)
+    )
+
+    // Check that state associated with the lost task attempt is cleaned up:
+    assert(taskScheduler.taskIdToExecutorId.isEmpty)
+    assert(taskScheduler.taskIdToTaskSetManager.isEmpty)
+    assert(taskScheduler.runningTasksByExecutors().get("executor0").isEmpty)
+
+    // Check that the executor has been marked as dead
+    assert(!taskScheduler.isExecutorAlive("executor0"))
+    assert(!taskScheduler.hasExecutorsAliveOnHost("host0"))
+    assert(taskScheduler.getExecutorsAliveOnHost("host0").isEmpty)
+  }
 }
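
Assuming a standard Spark checkout, the two new regression tests can be run
with the bundled sbt wrapper (command shown for illustration; adjust to your
build setup):

    build/sbt "core/testOnly org.apache.spark.scheduler.TaskSchedulerImplSuite"
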