diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index f219c5605b643c2ff72e5e4aefda0d759ad228dd..c14c12664f5ab671c8508f13b6d9b9bf406fb048 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -23,7 +23,6 @@ import java.util.LinkedList;
 import org.apache.avro.reflect.Nullable;
 
 import org.apache.spark.TaskContext;
-import org.apache.spark.TaskKilledException;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
@@ -291,8 +290,8 @@ public final class UnsafeInMemorySorter {
       // to avoid performance overhead. This check is added here in `loadNext()` instead of in
       // `hasNext()` because it's technically possible for the caller to be relying on
       // `getNumRecords()` instead of `hasNext()` to know when to stop.
-      if (taskContext != null && taskContext.isInterrupted()) {
-        throw new TaskKilledException();
+      if (taskContext != null) {
+        taskContext.killTaskIfInterrupted();
       }
       // This pointer points to a 4-byte record length, followed by the record's bytes
       final long recordPointer = array.get(offset + position);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index b6323c624b7b96637679984231c6d11cac523353..9521ab86a12d5e2a062d6b8ca904d49c3b9e0eca 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -24,7 +24,6 @@ import com.google.common.io.Closeables;
 
 import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
-import org.apache.spark.TaskKilledException;
 import org.apache.spark.io.NioBufferedFileInputStream;
 import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.storage.BlockId;
@@ -102,8 +101,8 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
     // to avoid performance overhead. This check is added here in `loadNext()` instead of in
     // `hasNext()` because it's technically possible for the caller to be relying on
     // `getNumRecords()` instead of `hasNext()` to know when to stop.
-    if (taskContext != null && taskContext.isInterrupted()) {
-      throw new TaskKilledException();
+    if (taskContext != null) {
+      taskContext.killTaskIfInterrupted();
     }
     recordLength = din.readInt();
     keyPrefix = din.readLong();
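Note on the placement of the new check: as the comment above says, a caller may drive this iterator by `getNumRecords()` rather than `hasNext()`, so the interruption check must live in `loadNext()`. A minimal sketch of such a caller, with a hypothetical `drain` helper:

```scala
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator

object DrainExample {
  // Hypothetical caller that loops on the record count instead of hasNext();
  // the kill check inside loadNext() is the only one this loop ever hits.
  def drain(iter: UnsafeSorterIterator): Unit = {
    var remaining = iter.getNumRecords()
    while (remaining > 0) {
      iter.loadNext() // killTaskIfInterrupted() fires here if the task was killed
      remaining -= 1
    }
  }
}
```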
diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
index 5c262bcbddf76a1f02298c803ff925ae107a3767..7f2c0068174b57b3a63f6dcb08e308b7cec65e48 100644
--- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
+++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -33,11 +33,8 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator
     // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
     // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
     // introduces an expensive read fence.
-    if (context.isInterrupted) {
-      throw new TaskKilledException
-    } else {
-      delegate.hasNext
-    }
+    context.killTaskIfInterrupted()
+    delegate.hasNext
   }
 
   def next(): T = delegate.next()
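For user code, which cannot call the `private[spark]` `killTaskIfInterrupted()`, the cooperative-cancellation pattern this patch consolidates looks roughly like the sketch below (`consume` and `process` are hypothetical stand-ins for per-record work):

```scala
import org.apache.spark.{TaskContext, TaskKilledException}

object CooperativeKill {
  // Sketch of cooperative cancellation in task-side user code. Spark-internal
  // iterators now delegate this check to context.killTaskIfInterrupted().
  def consume[T](records: Iterator[T], process: T => Unit): Unit = {
    val ctx = TaskContext.get()
    records.foreach { rec =>
      if (ctx != null && ctx.isInterrupted()) {
        throw new TaskKilledException("task was marked interrupted")
      }
      process(rec)
    }
  }
}
```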
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 0e36a30c933d00e8f1aca60d8537617dc19d44db..0225fd605607468b782fbaf8852ac7d09a6f910a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2249,6 +2249,24 @@ class SparkContext(config: SparkConf) extends Logging {
     dagScheduler.cancelStage(stageId, None)
   }
 
+  /**
+   * Kill and reschedule the given task attempt. Task ids can be obtained from the Spark UI
+   * or through SparkListener.onTaskStart.
+   *
+   * @param taskId the task ID to kill. This id uniquely identifies the task attempt.
+   * @param interruptThread whether to interrupt the thread running the task.
+   * @param reason the reason for killing the task, which should be a short string. If a task
+   *   is killed multiple times with different reasons, only one reason will be reported.
+   *
+   * @return Whether the task was successfully killed.
+   */
+  def killTaskAttempt(
+      taskId: Long,
+      interruptThread: Boolean = true,
+      reason: String = "killed via SparkContext.killTaskAttempt"): Boolean = {
+    dagScheduler.killTaskAttempt(taskId, interruptThread, reason)
+  }
+
   /**
    * Clean a closure to make it ready to be serialized and sent to tasks
    * (removes unreferenced variables in $outer's, updates REPL variables)
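A minimal driver-side sketch of the new API, mirroring the test added later in this patch; the attempt check and reason string are illustrative:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}

object KillFirstAttempts {
  // Kill the first attempt of every task as it starts; Spark reschedules it.
  def install(sc: SparkContext): Unit = {
    sc.addSparkListener(new SparkListener {
      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
        if (taskStart.taskInfo.attemptNumber == 0) {
          sc.killTaskAttempt(taskStart.taskInfo.taskId, interruptThread = true,
            reason = "killing first attempt for demonstration")
        }
      }
    })
  }
}
```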
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 5acfce17593b30f2513e9617887f139ba64012a2..0b87cd503d4fa722a81585b603cd3e6fcd0becf9 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -184,6 +184,16 @@ abstract class TaskContext extends Serializable {
   @DeveloperApi
   def getMetricsSources(sourceName: String): Seq[Source]
 
+  /**
+   * If the task is interrupted, throws TaskKilledException with the reason for the interrupt.
+   */
+  private[spark] def killTaskIfInterrupted(): Unit
+
+  /**
+   * If the task has been killed, returns the reason it was killed, otherwise None.
+   */
+  private[spark] def getKillReason(): Option[String]
+
   /**
    * Returns the manager for this task's managed memory.
    */
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index f346cf8d65806bfc78ebea1fdce523d2613fb46f..8cd1d1c96aa0a54a4a872f75bd6991a61e6dbc6a 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -59,8 +59,8 @@ private[spark] class TaskContextImpl(
   /** List of callback functions to execute when the task fails. */
   @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener]
 
-  // Whether the corresponding task has been killed.
-  @volatile private var interrupted: Boolean = false
+  // If defined, the corresponding task has been killed and this option contains the reason.
+  @volatile private var reasonIfKilled: Option[String] = None
 
   // Whether the task has completed.
   private var completed: Boolean = false
@@ -140,8 +140,19 @@ private[spark] class TaskContextImpl(
   }
 
   /** Marks the task for interruption, i.e. cancellation. */
-  private[spark] def markInterrupted(): Unit = {
-    interrupted = true
+  private[spark] def markInterrupted(reason: String): Unit = {
+    reasonIfKilled = Some(reason)
+  }
+
+  private[spark] override def killTaskIfInterrupted(): Unit = {
+    val reason = reasonIfKilled
+    if (reason.isDefined) {
+      throw new TaskKilledException(reason.get)
+    }
+  }
+
+  private[spark] override def getKillReason(): Option[String] = {
+    reasonIfKilled
   }
 
   @GuardedBy("this")
@@ -149,7 +160,7 @@ private[spark] class TaskContextImpl(
 
   override def isRunningLocally(): Boolean = false
 
-  override def isInterrupted(): Boolean = interrupted
+  override def isInterrupted(): Boolean = reasonIfKilled.isDefined
 
   override def getLocalProperty(key: String): String = localProperties.getProperty(key)
 
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 8c1b5f7bf0d9b8283932c78cc0abb6f3ff08a8ac..a76283e33fa65ab037edfb484cd61d1ec5389a6e 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -212,8 +212,8 @@ case object TaskResultLost extends TaskFailedReason {
  * Task was killed intentionally and needs to be rescheduled.
  */
 @DeveloperApi
-case object TaskKilled extends TaskFailedReason {
-  override def toErrorString: String = "TaskKilled (killed intentionally)"
+case class TaskKilled(reason: String) extends TaskFailedReason {
+  override def toErrorString: String = s"TaskKilled ($reason)"
   override def countTowardsTaskFailures: Boolean = false
 }
 
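Because `TaskKilled` is now a case class, listener code can pattern-match on the reason. A sketch (the listener name is hypothetical):

```scala
import org.apache.spark.TaskKilled
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Surfaces the kill reason that now flows through SparkListenerTaskEnd.
class KillReasonListener extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = taskEnd.reason match {
    case TaskKilled(reason) =>
      println(s"task ${taskEnd.taskInfo.taskId} killed: $reason")
    case _ => // other end reasons are unaffected by this patch
  }
}
```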
diff --git a/core/src/main/scala/org/apache/spark/TaskKilledException.scala b/core/src/main/scala/org/apache/spark/TaskKilledException.scala
index ad487c4efb87a3d7dda8d551abced46ea7173087..9dbf0d493be1156bcf82cca666ed11e1135aac59 100644
--- a/core/src/main/scala/org/apache/spark/TaskKilledException.scala
+++ b/core/src/main/scala/org/apache/spark/TaskKilledException.scala
@@ -24,4 +24,6 @@ import org.apache.spark.annotation.DeveloperApi
  * Exception thrown when a task is explicitly killed (i.e., task failure is expected).
  */
 @DeveloperApi
-class TaskKilledException extends RuntimeException
+class TaskKilledException(val reason: String) extends RuntimeException {
+  def this() = this("unknown reason")
+}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 04ae97ed3ccbe91b297960ba3ced973f66a81e5e..b0dd2fc187bafe439580fb11fbd5afbc5ee0d4b7 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -215,7 +215,7 @@ private[spark] class PythonRunner(
 
           case e: Exception if context.isInterrupted =>
             logDebug("Exception thrown after task interruption", e)
-            throw new TaskKilledException
+            throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
 
           case e: Exception if env.isStopped =>
             logDebug("Exception thrown after context is stopped", e)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index b376ecd301eab036c8367e60c6bcd1ec2d126d67..ba0096d874567ad2726d0a50b76a6a7401dc38d9 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -97,11 +97,11 @@ private[spark] class CoarseGrainedExecutorBackend(
         executor.launchTask(this, taskDesc)
       }
 
-    case KillTask(taskId, _, interruptThread) =>
+    case KillTask(taskId, _, interruptThread, reason) =>
       if (executor == null) {
         exitExecutor(1, "Received KillTask command but executor was null")
       } else {
-        executor.killTask(taskId, interruptThread)
+        executor.killTask(taskId, interruptThread, reason)
       }
 
     case StopExecutor =>
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 790c1ae942474b05651b872b881468039d34f488..99b1608010ddb667a8c7a0a398492b83c1d5e7df 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -158,7 +158,7 @@ private[spark] class Executor(
     threadPool.execute(tr)
   }
 
-  def killTask(taskId: Long, interruptThread: Boolean): Unit = {
+  def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
     val taskRunner = runningTasks.get(taskId)
     if (taskRunner != null) {
       if (taskReaperEnabled) {
@@ -168,7 +168,8 @@ private[spark] class Executor(
             case Some(existingReaper) => interruptThread && !existingReaper.interruptThread
           }
           if (shouldCreateReaper) {
-            val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread)
+            val taskReaper = new TaskReaper(
+              taskRunner, interruptThread = interruptThread, reason = reason)
             taskReaperForTask(taskId) = taskReaper
             Some(taskReaper)
           } else {
@@ -178,7 +179,7 @@ private[spark] class Executor(
         // Execute the TaskReaper from outside of the synchronized block.
         maybeNewTaskReaper.foreach(taskReaperPool.execute)
       } else {
-        taskRunner.kill(interruptThread = interruptThread)
+        taskRunner.kill(interruptThread = interruptThread, reason = reason)
       }
     }
   }
@@ -189,8 +190,9 @@ private[spark] class Executor(
    * tasks instead of taking the JVM down.
    * @param interruptThread whether to interrupt the task thread
    */
-  def killAllTasks(interruptThread: Boolean) : Unit = {
-    runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread))
+  def killAllTasks(interruptThread: Boolean, reason: String) : Unit = {
+    runningTasks.keys().asScala.foreach(t =>
+      killTask(t, interruptThread = interruptThread, reason = reason))
   }
 
   def stop(): Unit = {
@@ -217,8 +219,8 @@ private[spark] class Executor(
     val threadName = s"Executor task launch worker for task $taskId"
     private val taskName = taskDescription.name
 
-    /** Whether this task has been killed. */
-    @volatile private var killed = false
+    /** If specified, this task has been killed and this option contains the reason. */
+    @volatile private var reasonIfKilled: Option[String] = None
 
     @volatile private var threadId: Long = -1
 
@@ -239,13 +241,13 @@ private[spark] class Executor(
      */
     @volatile var task: Task[Any] = _
 
-    def kill(interruptThread: Boolean): Unit = {
-      logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
-      killed = true
+    def kill(interruptThread: Boolean, reason: String): Unit = {
+      logInfo(s"Executor is trying to kill $taskName (TID $taskId), reason: $reason")
+      reasonIfKilled = Some(reason)
       if (task != null) {
         synchronized {
           if (!finished) {
-            task.kill(interruptThread)
+            task.kill(interruptThread, reason)
           }
         }
       }
@@ -296,12 +298,13 @@ private[spark] class Executor(
 
         // If this task has been killed before we deserialized it, let's quit now. Otherwise,
         // continue executing the task.
-        if (killed) {
+        val killReason = reasonIfKilled
+        if (killReason.isDefined) {
           // Throw an exception rather than returning, because returning within a try{} block
           // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
           // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
           // for the task.
-          throw new TaskKilledException
+          throw new TaskKilledException(killReason.get)
         }
 
         logDebug("Task " + taskId + "'s epoch is " + task.epoch)
@@ -358,9 +361,7 @@ private[spark] class Executor(
         } else 0L
 
         // If the task has been killed, let's fail it.
-        if (task.killed) {
-          throw new TaskKilledException
-        }
+        task.context.killTaskIfInterrupted()
 
         val resultSer = env.serializer.newInstance()
         val beforeSerialization = System.currentTimeMillis()
@@ -426,15 +427,17 @@ private[spark] class Executor(
           setTaskFinishedAndClearInterruptStatus()
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
 
-        case _: TaskKilledException =>
-          logInfo(s"Executor killed $taskName (TID $taskId)")
+        case t: TaskKilledException =>
+          logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
           setTaskFinishedAndClearInterruptStatus()
-          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
 
-        case _: InterruptedException if task.killed =>
-          logInfo(s"Executor interrupted and killed $taskName (TID $taskId)")
+        case _: InterruptedException if task.reasonIfKilled.isDefined =>
+          val killReason = task.reasonIfKilled.getOrElse("unknown reason")
+          logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
           setTaskFinishedAndClearInterruptStatus()
-          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+          execBackend.statusUpdate(
+            taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
 
         case CausedBy(cDE: CommitDeniedException) =>
           val reason = cDE.toTaskFailedReason
@@ -512,7 +515,8 @@ private[spark] class Executor(
    */
   private class TaskReaper(
       taskRunner: TaskRunner,
-      val interruptThread: Boolean)
+      val interruptThread: Boolean,
+      val reason: String)
     extends Runnable {
 
     private[this] val taskId: Long = taskRunner.taskId
@@ -533,7 +537,7 @@ private[spark] class Executor(
         // Only attempt to kill the task once. If interruptThread = false then a second kill
         // attempt would be a no-op and if interruptThread = true then it may not be safe or
         // effective to interrupt multiple times:
-        taskRunner.kill(interruptThread = interruptThread)
+        taskRunner.kill(interruptThread = interruptThread, reason = reason)
         // Monitor the killed task until it exits. The synchronization logic here is complicated
         // because we don't want to synchronize on the taskRunner while possibly taking a thread
         // dump, but we also need to be careful to avoid races between checking whether the task
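The reaper bookkeeping above encodes one subtle rule: a second kill request creates a new `TaskReaper` only when it escalates from `interruptThread = false` to `true`; otherwise the existing reaper is reused. The decision condensed into a standalone sketch:

```scala
object ReaperPolicy {
  // existing: interruptThread flag of the reaper already registered for this
  // task, if any. Returns true when a new reaper should be created.
  def shouldCreateReaper(existing: Option[Boolean], interruptThread: Boolean): Boolean =
    existing match {
      case None => true
      case Some(existingInterrupts) => interruptThread && !existingInterrupts
    }
}
```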
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index d944f268755de168222b36d3b4a7725d01db83d6..09717316833a7af14453dcf715c8dec66e2b842d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -738,6 +738,15 @@ class DAGScheduler(
     eventProcessLoop.post(StageCancelled(stageId, reason))
   }
 
+  /**
+   * Kill a given task. It will be retried.
+   *
+   * @return Whether the task was successfully killed.
+   */
+  def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = {
+    taskScheduler.killTaskAttempt(taskId, interruptThread, reason)
+  }
+
   /**
    * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
    * the last fetch failure.
@@ -1353,7 +1362,7 @@ class DAGScheduler(
       case TaskResultLost =>
         // Do nothing here; the TaskScheduler handles these failures and resubmits the task.
 
-      case _: ExecutorLostFailure | TaskKilled | UnknownReason =>
+      case _: ExecutorLostFailure | _: TaskKilled | UnknownReason =>
         // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
         // will abort the job.
     }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index 8801a761afae31f463a97d33cd7c647d509cf9af..22db3350abfa7fde9fbfd35dfd3e2525bd0647c9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -30,8 +30,21 @@ private[spark] trait SchedulerBackend {
   def reviveOffers(): Unit
   def defaultParallelism(): Int
 
-  def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit =
+  /**
+   * Requests that an executor kills a running task.
+   *
+   * @param taskId Id of the task.
+   * @param executorId Id of the executor the task is running on.
+   * @param interruptThread Whether the executor should interrupt the task thread.
+   * @param reason The reason for the task kill.
+   */
+  def killTask(
+      taskId: Long,
+      executorId: String,
+      interruptThread: Boolean,
+      reason: String): Unit =
     throw new UnsupportedOperationException
+
   def isReady(): Boolean = true
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 70213722aae4f9cbd1dde317aad4934071c55485..46ef23f316a619c7a40eaeeb21c70b8f990f6eb5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -89,8 +89,8 @@ private[spark] abstract class Task[T](
     TaskContext.setTaskContext(context)
     taskThread = Thread.currentThread()
 
-    if (_killed) {
-      kill(interruptThread = false)
+    if (_reasonIfKilled != null) {
+      kill(interruptThread = false, _reasonIfKilled)
     }
 
     new CallerContext(
@@ -158,17 +158,17 @@ private[spark] abstract class Task[T](
   // The actual Thread on which the task is running, if any. Initialized in run().
   @volatile @transient private var taskThread: Thread = _
 
-  // A flag to indicate whether the task is killed. This is used in case context is not yet
-  // initialized when kill() is invoked.
-  @volatile @transient private var _killed = false
+  // If non-null, this task has been killed and the reason is as specified. This is used in case
+  // context is not yet initialized when kill() is invoked.
+  @volatile @transient private var _reasonIfKilled: String = null
 
   protected var _executorDeserializeTime: Long = 0
   protected var _executorDeserializeCpuTime: Long = 0
 
   /**
-   * Whether the task has been killed.
+   * If defined, this task has been killed and this option contains the reason.
    */
-  def killed: Boolean = _killed
+  def reasonIfKilled: Option[String] = Option(_reasonIfKilled)
 
   /**
    * Returns the amount of time spent deserializing the RDD and function to be run.
@@ -201,10 +201,11 @@ private[spark] abstract class Task[T](
    * be called multiple times.
    * If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread.
    */
-  def kill(interruptThread: Boolean) {
-    _killed = true
+  def kill(interruptThread: Boolean, reason: String) {
+    require(reason != null)
+    _reasonIfKilled = reason
     if (context != null) {
-      context.markInterrupted()
+      context.markInterrupted(reason)
     }
     if (interruptThread && taskThread != null) {
       taskThread.interrupt()
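`Task` stores the reason as a `@volatile String` (null meaning "not killed") and exposes it as an `Option`, because the flag must be settable before the `TaskContext` exists. The pattern in isolation:

```scala
// A nullable @volatile String doubles as the kill flag and its reason;
// Option(_) converts null to None at the public boundary.
class KillFlag {
  @volatile private var _reasonIfKilled: String = null

  def kill(reason: String): Unit = {
    require(reason != null)
    _reasonIfKilled = reason
  }

  def reasonIfKilled: Option[String] = Option(_reasonIfKilled)
}
```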
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index cd13eebe74a9928566f83777c42f47ea11c00658..3de7d1f7de22b37552c0f0e4be29ec90b7ebfb8c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -54,6 +54,13 @@ private[spark] trait TaskScheduler {
   // Cancel a stage.
   def cancelTasks(stageId: Int, interruptThread: Boolean): Unit
 
+  /**
+   * Kills a task attempt.
+   *
+   * @return Whether the task was successfully killed.
+   */
+  def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean
+
   // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
   def setDAGScheduler(dagScheduler: DAGScheduler): Unit
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index d6225a08739dd4847d076896c8d0bf7950fe5d0b..07aea773fa6328555188a6278e013e31c5d3fb22 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -241,7 +241,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](
         //    simply abort the stage.
         tsm.runningTasksSet.foreach { tid =>
           val execId = taskIdToExecutorId(tid)
-          backend.killTask(tid, execId, interruptThread)
+          backend.killTask(tid, execId, interruptThread, reason = "stage cancelled")
         }
         tsm.abort("Stage %s cancelled".format(stageId))
         logInfo("Stage %d was cancelled".format(stageId))
@@ -249,6 +249,18 @@ private[spark] class TaskSchedulerImpl private[scheduler](
     }
   }
 
+  override def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = {
+    logInfo(s"Killing task $taskId: $reason")
+    val execId = taskIdToExecutorId.get(taskId)
+    if (execId.isDefined) {
+      backend.killTask(taskId, execId.get, interruptThread, reason)
+      true
+    } else {
+      logWarning(s"Could not kill task $taskId because no task with that ID was found.")
+      false
+    }
+  }
+
   /**
    * Called to indicate that all task attempts (including speculated tasks) associated with the
    * given TaskSetManager have completed, so state associated with the TaskSetManager should be
@@ -469,7 +481,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](
       taskState: TaskState,
       reason: TaskFailedReason): Unit = synchronized {
     taskSetManager.handleFailedTask(tid, taskState, reason)
-    if (!taskSetManager.isZombie && taskState != TaskState.KILLED) {
+    if (!taskSetManager.isZombie && !taskSetManager.someAttemptSucceeded(tid)) {
       // Need to revive offers again now that the task set manager state has been updated to
       // reflect failed tasks that need to be re-run.
       backend.reviveOffers()
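The changed revive condition matters for the new API: a task killed via `killTaskAttempt` must be re-run, so `TaskState.KILLED` can no longer suppress the revive. Offers are now skipped only when the task set is finished or another attempt of the same task already succeeded; condensed:

```scala
object RevivePolicy {
  // Revive offers so a failed or killed task can be rescheduled, unless the
  // TaskSetManager is a zombie or some attempt of this task already succeeded.
  def shouldReviveOffers(isZombie: Boolean, someAttemptSucceeded: Boolean): Boolean =
    !isZombie && !someAttemptSucceeded
}
```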
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index f4a21bca79aa4c09820dd5bb254fc6efb30ece87..a177aab5f95dec7541f9fbf0674af44e87a28afb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -101,6 +101,10 @@ private[spark] class TaskSetManager(
 
   override def runningTasks: Int = runningTasksSet.size
 
+  def someAttemptSucceeded(tid: Long): Boolean = {
+    successful(taskInfos(tid).index)
+  }
+
   // True once no more tasks should be launched for this task set manager. TaskSetManagers enter
   // the zombie state once at least one attempt of each task has completed successfully, or if the
   // task set is aborted (for example, because it was killed).  TaskSetManagers remain in the zombie
@@ -722,7 +726,11 @@ private[spark] class TaskSetManager(
       logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " +
         s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " +
         s"as the attempt ${info.attemptNumber} succeeded on ${info.host}")
-      sched.backend.killTask(attemptInfo.taskId, attemptInfo.executorId, true)
+      sched.backend.killTask(
+        attemptInfo.taskId,
+        attemptInfo.executorId,
+        interruptThread = true,
+        reason = "another attempt succeeded")
     }
     if (!successful(index)) {
       tasksSuccessful += 1
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 2898cd7d17ca08bf291228ef502d237166eb446a..6b49bd699a13a318321c273d69b6a5d7147f9079 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -40,7 +40,7 @@ private[spark] object CoarseGrainedClusterMessages {
   // Driver to executors
   case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage
 
-  case class KillTask(taskId: Long, executor: String, interruptThread: Boolean)
+  case class KillTask(taskId: Long, executor: String, interruptThread: Boolean, reason: String)
     extends CoarseGrainedClusterMessage
 
   case class KillExecutorsOnHost(host: String)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 7e2cfaccfc7ba74898c77037182ed1e80747990f..4eedaaea61195a2b941e60be2b8832a6a5d680d6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -132,10 +132,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
       case ReviveOffers =>
         makeOffers()
 
-      case KillTask(taskId, executorId, interruptThread) =>
+      case KillTask(taskId, executorId, interruptThread, reason) =>
         executorDataMap.get(executorId) match {
           case Some(executorInfo) =>
-            executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread))
+            executorInfo.executorEndpoint.send(
+              KillTask(taskId, executorId, interruptThread, reason))
           case None =>
             // Ignoring the task kill since the executor is not registered.
             logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.")
@@ -428,8 +429,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
     driverEndpoint.send(ReviveOffers)
   }
 
-  override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
-    driverEndpoint.send(KillTask(taskId, executorId, interruptThread))
+  override def killTask(
+      taskId: Long, executorId: String, interruptThread: Boolean, reason: String) {
+    driverEndpoint.send(KillTask(taskId, executorId, interruptThread, reason))
   }
 
   override def defaultParallelism(): Int = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
index 625f998cd4608ad5a39ece06da576323c5e99fc6..35509bc2f85b9028f7eeb74568f254e26c6cdea8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
@@ -34,7 +34,7 @@ private case class ReviveOffers()
 
 private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
 
-private case class KillTask(taskId: Long, interruptThread: Boolean)
+private case class KillTask(taskId: Long, interruptThread: Boolean, reason: String)
 
 private case class StopExecutor()
 
@@ -70,8 +70,8 @@ private[spark] class LocalEndpoint(
         reviveOffers()
       }
 
-    case KillTask(taskId, interruptThread) =>
-      executor.killTask(taskId, interruptThread)
+    case KillTask(taskId, interruptThread, reason) =>
+      executor.killTask(taskId, interruptThread, reason)
   }
 
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
@@ -143,8 +143,9 @@ private[spark] class LocalSchedulerBackend(
   override def defaultParallelism(): Int =
     scheduler.conf.getInt("spark.default.parallelism", totalCores)
 
-  override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
-    localEndpoint.send(KillTask(taskId, interruptThread))
+  override def killTask(
+      taskId: Long, executorId: String, interruptThread: Boolean, reason: String) {
+    localEndpoint.send(KillTask(taskId, interruptThread, reason))
   }
 
   override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index d161843dd22304b837fa7d30e5665c06371a3151..e53d6907bc404b48a3bca3537a0bdb5e862c2dd2 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -342,7 +342,7 @@ private[spark] object UIUtils extends Logging {
       completed: Int,
       failed: Int,
       skipped: Int,
-      killed: Int,
+      reasonToNumKilled: Map[String, Int],
       total: Int): Seq[Node] = {
     val completeWidth = "width: %s%%".format((completed.toDouble/total)*100)
     // started + completed can be > total when there are speculative tasks
@@ -354,7 +354,10 @@ private[spark] object UIUtils extends Logging {
         {completed}/{total}
         { if (failed > 0) s"($failed failed)" }
         { if (skipped > 0) s"($skipped skipped)" }
-        { if (killed > 0) s"($killed killed)" }
+        { reasonToNumKilled.toSeq.sortBy(-_._2).map {
+            case (reason, count) => s"($count killed: $reason)"
+          }
+        }
       </span>
       <div class="bar bar-completed" style={completeWidth}></div>
       <div class="bar bar-running" style={startWidth}></div>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index d217f558045f2b829ec66c48f94c470a97e54f07..18be0870746e912aed6a3c9f97a17d6cf2f78285 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -630,8 +630,8 @@ private[ui] class JobPagedTable(
       </td>
       <td class="progress-cell">
         {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks,
-        failed = job.numFailedTasks, skipped = job.numSkippedTasks, killed = job.numKilledTasks,
-        total = job.numTasks - job.numSkippedTasks)}
+        failed = job.numFailedTasks, skipped = job.numSkippedTasks,
+        reasonToNumKilled = job.reasonToNumKilled, total = job.numTasks - job.numSkippedTasks)}
       </td>
     </tr>
   }
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
index cd1b02addc789c60d13cb7ade6efc6024a8378f1..52f41298a17295cc1172eb5d31b41a0ab3b5b7bc 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
@@ -133,9 +133,9 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage
             </td>
             <td>{executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}</td>
             <td sorttable_customkey={v.taskTime.toString}>{UIUtils.formatDuration(v.taskTime)}</td>
-            <td>{v.failedTasks + v.succeededTasks + v.killedTasks}</td>
+            <td>{v.failedTasks + v.succeededTasks + v.reasonToNumKilled.map(_._2).sum}</td>
             <td>{v.failedTasks}</td>
-            <td>{v.killedTasks}</td>
+            <td>{v.reasonToNumKilled.map(_._2).sum}</td>
             <td>{v.succeededTasks}</td>
             {if (stageData.hasInput) {
               <td sorttable_customkey={v.inputBytes.toString}>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index e87caff426436b72d3b0b13efcbb057ba6e93a18..1cf03e1541d141cd4a6daeed6f4ab63d6a7fc659 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -371,8 +371,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
       taskEnd.reason match {
         case Success =>
           execSummary.succeededTasks += 1
-        case TaskKilled =>
-          execSummary.killedTasks += 1
+        case kill: TaskKilled =>
+          execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated(
+            kill.reason, execSummary.reasonToNumKilled.getOrElse(kill.reason, 0) + 1)
         case _ =>
           execSummary.failedTasks += 1
       }
@@ -385,9 +386,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
             stageData.completedIndices.add(info.index)
             stageData.numCompleteTasks += 1
             None
-          case TaskKilled =>
-            stageData.numKilledTasks += 1
-            Some(TaskKilled.toErrorString)
+          case kill: TaskKilled =>
+            stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated(
+              kill.reason, stageData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1)
+            Some(kill.toErrorString)
           case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates
             stageData.numFailedTasks += 1
             Some(e.toErrorString)
@@ -422,8 +424,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
         taskEnd.reason match {
           case Success =>
             jobData.numCompletedTasks += 1
-          case TaskKilled =>
-            jobData.numKilledTasks += 1
+          case kill: TaskKilled =>
+            jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated(
+              kill.reason, jobData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1)
           case _ =>
             jobData.numFailedTasks += 1
         }
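The listener changes repeat one idiom three times: incrementing a per-reason counter held in an immutable `Map`. In isolation:

```scala
object KillCounts {
  // Increment reason's count in an immutable Map, starting from 0 if absent.
  def increment(counts: Map[String, Int], reason: String): Map[String, Int] =
    counts.updated(reason, counts.getOrElse(reason, 0) + 1)
}

// e.g. KillCounts.increment(Map("stage cancelled" -> 1), "stage cancelled")
//      == Map("stage cancelled" -> 2)
```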
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index e1fa9043b6a15ef0e3cebf26dd7c1b3dd9c630c1..f4caad0f58715900283f5e620d5020529c19685f 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -300,7 +300,7 @@ private[ui] class StagePagedTable(
         <td class="progress-cell">
           {UIUtils.makeProgressBar(started = stageData.numActiveTasks,
           completed = stageData.completedIndices.size, failed = stageData.numFailedTasks,
-          skipped = 0, killed = stageData.numKilledTasks, total = info.numTasks)}
+          skipped = 0, reasonToNumKilled = stageData.reasonToNumKilled, total = info.numTasks)}
         </td>
         <td>{data.inputReadWithUnit}</td>
         <td>{data.outputWriteWithUnit}</td>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index 073f7edfc2fe95ca35034c6ebcf29ea319f72567..ac1a74ad8029d142e79b51ec08b47b0968b51dcb 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -32,7 +32,7 @@ private[spark] object UIData {
     var taskTime : Long = 0
     var failedTasks : Int = 0
     var succeededTasks : Int = 0
-    var killedTasks : Int = 0
+    var reasonToNumKilled : Map[String, Int] = Map.empty
     var inputBytes : Long = 0
     var inputRecords : Long = 0
     var outputBytes : Long = 0
@@ -64,7 +64,7 @@ private[spark] object UIData {
     var numCompletedTasks: Int = 0,
     var numSkippedTasks: Int = 0,
     var numFailedTasks: Int = 0,
-    var numKilledTasks: Int = 0,
+    var reasonToNumKilled: Map[String, Int] = Map.empty,
     /* Stages */
     var numActiveStages: Int = 0,
     // This needs to be a set instead of a simple count to prevent double-counting of rerun stages:
@@ -78,7 +78,7 @@ private[spark] object UIData {
     var numCompleteTasks: Int = _
     var completedIndices = new OpenHashSet[Int]()
     var numFailedTasks: Int = _
-    var numKilledTasks: Int = _
+    var reasonToNumKilled: Map[String, Int] = Map.empty
 
     var executorRunTime: Long = _
     var executorCpuTime: Long = _
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 4b4d2d10cbf8d7d45e6a9919d8be8af5344a4797..2cb88919c8c83540922b5f781e419a585a32bf5b 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -390,6 +390,8 @@ private[spark] object JsonProtocol {
         ("Executor ID" -> executorId) ~
         ("Exit Caused By App" -> exitCausedByApp) ~
         ("Loss Reason" -> reason.map(_.toString))
+      case taskKilled: TaskKilled =>
+        ("Kill Reason" -> taskKilled.reason)
       case _ => Utils.emptyJson
     }
     ("Reason" -> reason) ~ json
@@ -877,7 +879,10 @@ private[spark] object JsonProtocol {
           }))
         ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates)
       case `taskResultLost` => TaskResultLost
-      case `taskKilled` => TaskKilled
+      case `taskKilled` =>
+        val killReason = Utils.jsonOption(json \ "Kill Reason")
+          .map(_.extract[String]).getOrElse("unknown reason")
+        TaskKilled(killReason)
       case `taskCommitDenied` =>
         // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON
         // de/serialization logic was not added until 1.5.1. To provide backward compatibility
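A round-trip sketch of the new field (runnable only from Spark-internal code, since `JsonProtocol` is `private[spark]`); old event logs, which lack "Kill Reason", deserialize to `TaskKilled("unknown reason")`:

```scala
import org.apache.spark.TaskKilled
import org.apache.spark.util.JsonProtocol

object KillReasonJson {
  def roundTrip(): Unit = {
    val json = JsonProtocol.taskEndReasonToJson(TaskKilled("another attempt succeeded"))
    // json renders roughly as {"Reason":"TaskKilled","Kill Reason":"another attempt succeeded"}
    assert(JsonProtocol.taskEndReasonFromJson(json) == TaskKilled("another attempt succeeded"))
  }
}
```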
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index d08a162feda03e10cd5fd419f155a9d991ef3134..2c947556dfd30d65dc434b2e294b1184cd7a2180 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -34,7 +34,7 @@ import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFor
 import org.scalatest.concurrent.Eventually
 import org.scalatest.Matchers._
 
-import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart}
 import org.apache.spark.util.Utils
 
 
@@ -540,6 +540,48 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
     }
   }
 
+  // Launches one task that will run forever. Once the SparkListener detects that the task
+  // has started, kill and re-schedule it. The second attempt of the task will complete
+  // immediately. If this test times out, the first attempt was not killed successfully.
+  test("Killing tasks") {
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+
+    SparkContextSuite.isTaskStarted = false
+    SparkContextSuite.taskKilled = false
+    SparkContextSuite.taskSucceeded = false
+
+    val listener = new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+        eventually(timeout(10.seconds)) {
+          assert(SparkContextSuite.isTaskStarted)
+        }
+        if (!SparkContextSuite.taskKilled) {
+          SparkContextSuite.taskKilled = true
+          sc.killTaskAttempt(taskStart.taskInfo.taskId, true, "first attempt will hang")
+        }
+      }
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+        if (taskEnd.taskInfo.attemptNumber == 1 && taskEnd.reason == Success) {
+          SparkContextSuite.taskSucceeded = true
+        }
+      }
+    }
+    sc.addSparkListener(listener)
+    eventually(timeout(20.seconds)) {
+      sc.parallelize(1 to 1).foreach { x =>
+        // first attempt will hang
+        if (!SparkContextSuite.isTaskStarted) {
+          SparkContextSuite.isTaskStarted = true
+          Thread.sleep(9999999)
+        }
+        // second attempt succeeds immediately
+      }
+    }
+    eventually(timeout(10.seconds)) {
+      assert(SparkContextSuite.taskSucceeded)
+    }
+  }
+
   test("SPARK-19446: DebugFilesystem.assertNoOpenStreams should report " +
     "open streams to help debugging") {
     val fs = new DebugFilesystem()
@@ -555,11 +597,12 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
     assert(exc.getCause() != null)
     stream.close()
   }
-
 }
 
 object SparkContextSuite {
   @volatile var cancelJob = false
   @volatile var cancelStage = false
   @volatile var isTaskStarted = false
+  @volatile var taskKilled = false
+  @volatile var taskSucceeded = false
 }
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 8150fff2d018da61480dbdfdbeb673a5c0d52a5a..f47e574b4fc4b1548f90881df20f76361c81c919 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -110,14 +110,14 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
       }
       // we know the task will be started, but not yet deserialized, because of the latches we
       // use in mockExecutorBackend.
-      executor.killAllTasks(true)
+      executor.killAllTasks(true, "test")
       executorSuiteHelper.latch2.countDown()
       if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) {
         fail("executor did not send second status update in time")
       }
 
       // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED`
-      assert(executorSuiteHelper.testFailedReason === TaskKilled)
+      assert(executorSuiteHelper.testFailedReason === TaskKilled("test"))
       assert(executorSuiteHelper.taskState === TaskState.KILLED)
     }
     finally {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index a9389003d5db8124649410fdb52aa45b50fad23e..a10941b579fe2f1cad506ada08a2bdc0758541f6 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -126,6 +126,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     override def cancelTasks(stageId: Int, interruptThread: Boolean) {
       cancelledStages += stageId
     }
+    override def killTaskAttempt(
+      taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
     override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
     override def defaultParallelism() = 2
     override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
@@ -552,6 +554,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
       override def cancelTasks(stageId: Int, interruptThread: Boolean) {
         throw new UnsupportedOperationException
       }
+      override def killTaskAttempt(
+          taskId: Long, interruptThread: Boolean, reason: String): Boolean = {
+        throw new UnsupportedOperationException
+      }
       override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
       override def defaultParallelism(): Int = 2
       override def executorHeartbeatReceived(
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
index 37c124a726be25662c18972fb2c54e940f1579f2..ba56af8215cd7a82bffc43aee6f3e8ed4734f408 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
@@ -79,6 +79,8 @@ private class DummyTaskScheduler extends TaskScheduler {
   override def stop(): Unit = {}
   override def submitTasks(taskSet: TaskSet): Unit = {}
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {}
+  override def killTaskAttempt(
+    taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
   override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
   override def defaultParallelism(): Int = 2
   override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
index 38b9d40329d48b8fe0ec3e7b7c039098f5f00567..e51e6a0d3ff6b9a5c78b7918f38115e187b76ad4 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
@@ -176,13 +176,13 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
     assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter))
     // The non-authorized committer fails
     outputCommitCoordinator.taskCompleted(
-      stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled)
+      stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test"))
     // New tasks should still not be able to commit because the authorized committer has not failed
     assert(
       !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1))
     // The authorized committer now fails, clearing the lock
     outputCommitCoordinator.taskCompleted(
-      stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled)
+      stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test"))
     // A new task should now be allowed to become the authorized committer
     assert(
       outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
index 398ac3d6202dbd0c027f3710350e3cde1a29dc88..8103983c4392a2a471d958145f7710d7d615610c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
@@ -410,7 +410,8 @@ private[spark] abstract class MockBackend(
     }
   }
 
-  override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {
+  override def killTask(
+      taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = {
     // We have to implement this b/c of SPARK-15385.
     // It's OK for this to be a no-op, because even if a backend does implement killTask,
     // it really can only be "best-effort" in any case, and the scheduler should be robust to that.
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 064af381a76d2808ac9a074b2aa91ba8573717ac..132caef0978fbad2edc6ac94b93fc53d1270413b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -677,7 +677,11 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
 
     val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"))
     sched.initialize(new FakeSchedulerBackend() {
-      override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {}
+      override def killTask(
+        taskId: Long,
+        executorId: String,
+        interruptThread: Boolean,
+        reason: String): Unit = {}
     })
 
     // Keep track of the number of tasks that are resubmitted,
@@ -935,7 +939,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     // Complete the speculative attempt for the running task
     manager.handleSuccessfulTask(4, createTaskResult(3, accumUpdatesByTask(3)))
     // Verify that it kills other running attempt
-    verify(sched.backend).killTask(3, "exec2", true)
+    verify(sched.backend).killTask(3, "exec2", true, "another attempt succeeded")
     // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be
     // killed, so the FakeTaskScheduler is only told about the successful completion
     // of the speculated task.
@@ -1023,14 +1027,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     manager.handleSuccessfulTask(speculativeTask.taskId, createTaskResult(3, accumUpdatesByTask(3)))
     // Verify that it kills other running attempt
     val origTask = originalTasks(speculativeTask.index)
-    verify(sched.backend).killTask(origTask.taskId, "exec2", true)
+    verify(sched.backend).killTask(origTask.taskId, "exec2", true, "another attempt succeeded")
     // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be
     // killed, so the FakeTaskScheduler is only told about the successful completion
     // of the speculated task.
     assert(sched.endedTasks(3) === Success)
     // also because the scheduler is a mock, our manager isn't notified about the task killed event,
     // so we do that manually
-    manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled)
+    manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled("test"))
     // this task has "failed" 4 times, but one of them doesn't count, so keep running the stage
     assert(manager.tasksSuccessful === 4)
     assert(!manager.isZombie)
@@ -1047,7 +1051,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
       createTaskResult(3, accumUpdatesByTask(3)))
     // Verify that it kills other running attempt
     val origTask2 = originalTasks(speculativeTask2.index)
-    verify(sched.backend).killTask(origTask2.taskId, "exec2", true)
+    verify(sched.backend).killTask(origTask2.taskId, "exec2", true, "another attempt succeeded")
     assert(manager.tasksSuccessful === 5)
     assert(manager.isZombie)
   }
@@ -1102,8 +1106,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
       ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None))
     tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED,
       TaskCommitDenied(0, 2, 0))
-    tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED,
-      TaskKilled)
+    tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, TaskKilled("test"))
 
     // Make sure that the blacklist ignored all of the task failures above, since they aren't
     // the fault of the executor where the task was running.
diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
index 6335d905c0fbfca393208fdc34b0eb76216e4468..c770fd5da76f7863860bb08e43dd6dd4898a6353 100644
--- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
@@ -110,7 +110,7 @@ class UIUtilsSuite extends SparkFunSuite {
   }
 
   test("SPARK-11906: Progress bar should not overflow because of speculative tasks") {
-    val generated = makeProgressBar(2, 3, 0, 0, 0, 4).head.child.filter(_.label == "div")
+    val generated = makeProgressBar(2, 3, 0, 0, Map.empty, 4).head.child.filter(_.label == "div")
     val expected = Seq(
       <div class="bar bar-completed" style="width: 75.0%"></div>,
       <div class="bar bar-running" style="width: 25.0%"></div>
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index e3127da9a6b24a7be5109d43c05a91d1ab0a8ce7..93964a2d5674335cd05b652d9fe039eb8e618bc9 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -274,8 +274,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
 
     // Make sure killed tasks are accounted for correctly.
     listener.onTaskEnd(
-      SparkListenerTaskEnd(task.stageId, 0, taskType, TaskKilled, taskInfo, metrics))
-    assert(listener.stageIdToData((task.stageId, 0)).numKilledTasks === 1)
+      SparkListenerTaskEnd(
+        task.stageId, 0, taskType, TaskKilled("test"), taskInfo, metrics))
+    assert(listener.stageIdToData((task.stageId, 0)).reasonToNumKilled === Map("test" -> 1))
 
     // Make sure we count success as success.
     listener.onTaskEnd(
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 9f76c74bce89e2bd5f55bbf8f4f93e6ae3cf554a..a64dbeae472946b7562aaa4c9ea9219c84bde69a 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -164,7 +164,7 @@ class JsonProtocolSuite extends SparkFunSuite {
     testTaskEndReason(fetchMetadataFailed)
     testTaskEndReason(exceptionFailure)
     testTaskEndReason(TaskResultLost)
-    testTaskEndReason(TaskKilled)
+    testTaskEndReason(TaskKilled("test"))
     testTaskEndReason(TaskCommitDenied(2, 3, 4))
     testTaskEndReason(ExecutorLostFailure("100", true, Some("Induced failure")))
     testTaskEndReason(UnknownReason)
@@ -676,7 +676,8 @@ private[spark] object JsonProtocolSuite extends Assertions {
         assert(r1.fullStackTrace === r2.fullStackTrace)
         assertSeqEquals[AccumulableInfo](r1.accumUpdates, r2.accumUpdates, (a, b) => a.equals(b))
       case (TaskResultLost, TaskResultLost) =>
-      case (TaskKilled, TaskKilled) =>
+      case (r1: TaskKilled, r2: TaskKilled) =>
+        assert(r1.reason == r2.reason)
       case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1),
           TaskCommitDenied(jobId2, partitionId2, attemptNumber2)) =>
         assert(jobId1 === jobId2)
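
For testTaskEndReason(TaskKilled("test")) to round-trip, JsonProtocol has to
carry the reason in both directions. A hedged json4s sketch of both sides; the
"Kill Reason" field name and the fallback for event logs written before this
change are assumptions, not verified against the shipped JsonProtocol:

    import org.json4s._
    import org.json4s.JsonDSL._
    import org.apache.spark.TaskKilled

    // Writing: attach the reason next to the usual "Reason" discriminator.
    def taskKilledToJson(tk: TaskKilled): JValue =
      ("Reason" -> "TaskKilled") ~ ("Kill Reason" -> tk.reason)

    // Reading: tolerate logs that predate the field.
    def taskKilledFromJson(json: JValue): TaskKilled =
      json \ "Kill Reason" match {
        case JString(s) => TaskKilled(s)
        case _ => TaskKilled("unknown reason") // older logs omit the field
      }
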
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 9925a8ba7266247b16866c84d2fe540e2d9b8ccf..8ce9367c9b44666567b37136f8a706a875dd77f9 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -66,6 +66,19 @@ object MimaExcludes {
     // [SPARK-17161] Removing Python-friendly constructors not needed
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this"),
 
+    // [SPARK-19820] Allow reason to be specified to task kill
+    ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.TaskKilled$"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productElement"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productArity"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.canEqual"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productIterator"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.countTowardsTaskFailures"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productPrefix"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.toErrorString"),
+    ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.TaskKilled.toString"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.killTaskIfInterrupted"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getKillReason"),
+
     // [SPARK-19876] Add one time trigger, and improve Trigger APIs
     ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.streaming.Trigger"),
     ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime")
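
This block of exclusions is needed because turning a case object into a case
class is binary-incompatible: the synthesized singleton type TaskKilled$ and
the Product*/toErrorString members generated for the object all change shape,
and TaskContext gains two new abstract methods. A before/after sketch of the
source-level change that triggers the filters, reusing the TaskFailedReason
stand-in from earlier (the old error string here is an assumption about the
prior sources):

    // Before: a payload-free singleton, whose generated members the
    // DirectMissingMethodProblem filters above excuse.
    case object TaskKilled extends TaskFailedReason {
      override def toErrorString: String = "TaskKilled (killed intentionally)"
      override def countTowardsTaskFailures: Boolean = false
    }

    // After (replaces the object above): a per-kill payload. Bytecode linked
    // against the old object no longer resolves, hence the MiMa entries.
    case class TaskKilled(reason: String) extends TaskFailedReason {
      override def toErrorString: String = s"TaskKilled ($reason)"
      override def countTowardsTaskFailures: Boolean = false
    }
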
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index b25253978258003b23a4e9a1e6469c7834a42292..a086ec7ea2da66999925b0dc16c75a3a4e848840 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -104,7 +104,8 @@ private[spark] class MesosExecutorBackend
       logError("Received KillTask but executor was null")
     } else {
       // TODO: Determine the 'interruptOnCancel' property set for the given job.
-      executor.killTask(t.getValue.toLong, interruptThread = false)
+      executor.killTask(
+        t.getValue.toLong, interruptThread = false, reason = "killed by mesos")
     }
   }
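
Executor.killTask now requires a reason as well; since the Mesos KillTask
event carries no message, this backend supplies the fixed string above. A
plausible sketch of how the reason then flows into the running task (the
TaskRunner internals and field names here are assumptions):

    // Plausible executor-side flow (simplified): record the reason on the
    // task's runner, then optionally interrupt its thread.
    def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
      val taskRunner = runningTasks.get(taskId) // ConcurrentHashMap lookup
      if (taskRunner != null) {
        taskRunner.kill(interruptThread, reason)
      }
    }
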
 
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala
index f198f8893b3db2ca7138f6ae4876f48433d1e4a2..735c879c63c55e9713c01c325ff56a966173f6e3 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala
@@ -428,7 +428,8 @@ private[spark] class MesosFineGrainedSchedulerBackend(
     recordSlaveLost(d, slaveId, ExecutorExited(status, exitCausedByApp = true))
   }
 
-  override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {
+  override def killTask(
+      taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = {
     schedulerDriver.killTask(
       TaskID.newBuilder()
         .setValue(taskId.toString).build()
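
The fine-grained backend widens its override to satisfy the new trait, but the
Mesos killTask call built above has no field for a message, so the reason is
dropped at this boundary; it still reaches logs and the UI on other backends.
An assumed call-site shape in the scheduler, to show the string that would be
discarded here (the "Stage cancelled" reason is illustrative):

    // Assumed caller: the scheduler threads a reason through to whichever
    // backend is active; this backend simply ignores it.
    backend.killTask(tid, execId, interruptThread = false,
      reason = "Stage cancelled")
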
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index a89d172a911abec1f2588e96ac6290f7dccb5a16..9df20731c71d5069070ed0492ef1dec069da95bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -101,9 +101,7 @@ class FileScanRDD(
         // Kill the task in case it has been marked as killed. This logic is from
         // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
         // to avoid performance overhead.
-        if (context.isInterrupted()) {
-          throw new TaskKilledException
-        }
+        context.killTaskIfInterrupted()
         (currentIterator != null && currentIterator.hasNext) || nextIterator()
       }
       def next(): Object = {
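
This is the same consolidation applied to the sorter classes elsewhere in the
patch: the inline check-and-throw pair becomes one TaskContext call, which
also lets the thrown exception carry the kill reason. A plausible sketch of
the helper on the concrete TaskContextImpl (the field name, and a
TaskKilledException that accepts a reason, are assumptions):

    // Plausible TaskContextImpl internals: the kill reason doubles as the
    // interrupt flag, so a single volatile read covers both.
    @volatile private var reasonIfKilled: Option[String] = None

    def killTaskIfInterrupted(): Unit = {
      val reason = reasonIfKilled
      if (reason.isDefined) {
        throw new TaskKilledException(reason.get)
      }
    }
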
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala
index 1352ca1c4c95f384c2cf2ac20bce2a2ee77ce7fe..70b4bb466c46b188503acdf7517c580c2b9c790a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala
@@ -97,7 +97,7 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long)
         completed = batch.numCompletedOutputOp,
         failed = batch.numFailedOutputOp,
         skipped = 0,
-        killed = 0,
+        reasonToNumKilled = Map.empty,
         total = batch.outputOperations.size)
       }
     </td>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
index 1a87fc790f91b5d5fc3aca5568769fd266bd62a7..f55af6a5cc35899efdade647ff9e7d25f9821078 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
@@ -146,7 +146,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
             completed = sparkJob.numCompletedTasks,
             failed = sparkJob.numFailedTasks,
             skipped = sparkJob.numSkippedTasks,
-            killed = sparkJob.numKilledTasks,
+            reasonToNumKilled = sparkJob.reasonToNumKilled,
             total = sparkJob.numTasks - sparkJob.numSkippedTasks)
         }
       </td>
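
The streaming pages only thread the new map through: AllBatchesTable has no
killed tasks to report, so it passes Map.empty, while BatchPage forwards the
per-job breakdown. One natural consumer of that breakdown, sketched as an
assumption about the rendering rather than actual markup, is a
reason-annotated summary for the killed segment:

    // Illustrative summary text for a killed-tasks tooltip.
    def killedSummary(reasonToNumKilled: Map[String, Int]): String =
      reasonToNumKilled.toSeq.sortBy(-_._2)
        .map { case (reason, n) => s"$n killed: $reason" }
        .mkString(", ")

    killedSummary(Map("another attempt succeeded" -> 2, "killed by mesos" -> 1))
    // => "2 killed: another attempt succeeded, 1 killed: killed by mesos"
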