diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9501dd9cd8e932fe6ad50092ea5c3680d717e617..3346f6dd1f975fe687a65a460efe1b3cbf1f0ee1 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -84,6 +84,16 @@ private[spark] class Executor( // Start worker thread pool private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker") private val executorSource = new ExecutorSource(threadPool, executorId) + // Pool used for threads that supervise task killing / cancellation + private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper") + // For tasks which are in the process of being killed, this map holds the most recently created + // TaskReaper. All accesses to this map should be synchronized on the map itself (this isn't + // a ConcurrentHashMap because we use the synchronization for purposes other than simply guarding + // the integrity of the map's internal state). The purpose of this map is to prevent the creation + // of a separate TaskReaper for every killTask() of a given task. Instead, this map allows us to + // track whether an existing TaskReaper fulfills the role of a TaskReaper that we would otherwise + // create. The map key is a task id. + private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]() if (!isLocal) { env.metricsSystem.registerSource(executorSource) @@ -93,6 +103,9 @@ private[spark] class Executor( // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) + // Whether to monitor killed / interrupted tasks + private val taskReaperEnabled = conf.getBoolean("spark.task.reaper.enabled", false) + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() @@ -148,9 +161,27 @@ private[spark] class Executor( } def killTask(taskId: Long, interruptThread: Boolean): Unit = { - val tr = runningTasks.get(taskId) - if (tr != null) { - tr.kill(interruptThread) + val taskRunner = runningTasks.get(taskId) + if (taskRunner != null) { + if (taskReaperEnabled) { + val maybeNewTaskReaper: Option[TaskReaper] = taskReaperForTask.synchronized { + val shouldCreateReaper = taskReaperForTask.get(taskId) match { + case None => true + case Some(existingReaper) => interruptThread && !existingReaper.interruptThread + } + if (shouldCreateReaper) { + val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread) + taskReaperForTask(taskId) = taskReaper + Some(taskReaper) + } else { + None + } + } + // Execute the TaskReaper from outside of the synchronized block. + maybeNewTaskReaper.foreach(taskReaperPool.execute) + } else { + taskRunner.kill(interruptThread = interruptThread) + } } } @@ -161,12 +192,7 @@ private[spark] class Executor( * @param interruptThread whether to interrupt the task thread */ def killAllTasks(interruptThread: Boolean) : Unit = { - // kill all the running tasks - for (taskRunner <- runningTasks.values().asScala) { - if (taskRunner != null) { - taskRunner.kill(interruptThread) - } - } + runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread)) } def stop(): Unit = { @@ -192,13 +218,21 @@ private[spark] class Executor( serializedTask: ByteBuffer) extends Runnable { + val threadName = s"Executor task launch worker for task $taskId" + /** Whether this task has been killed. */ @volatile private var killed = false + @volatile private var threadId: Long = -1 + + def getThreadId: Long = threadId + /** Whether this task has been finished. */ @GuardedBy("TaskRunner.this") private var finished = false + def isFinished: Boolean = synchronized { finished } + /** How much the JVM process has spent in GC when the task starts to run. */ @volatile var startGCTime: Long = _ @@ -229,9 +263,15 @@ private[spark] class Executor( // ClosedByInterruptException during execBackend.statusUpdate which causes // Executor to crash Thread.interrupted() + // Notify any waiting TaskReapers. Generally there will only be one reaper per task but there + // is a rare corner-case where one task can have two reapers in case cancel(interrupt=False) + // is followed by cancel(interrupt=True). Thus we use notifyAll() to avoid a lost wakeup: + notifyAll() } override def run(): Unit = { + threadId = Thread.currentThread.getId + Thread.currentThread.setName(threadName) val threadMXBean = ManagementFactory.getThreadMXBean val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) val deserializeStartTime = System.currentTimeMillis() @@ -431,6 +471,117 @@ private[spark] class Executor( } } + /** + * Supervises the killing / cancellation of a task by sending the interrupted flag, optionally + * sending a Thread.interrupt(), and monitoring the task until it finishes. + * + * Spark's current task cancellation / task killing mechanism is "best effort" because some tasks + * may not be interruptable or may not respond to their "killed" flags being set. If a significant + * fraction of a cluster's task slots are occupied by tasks that have been marked as killed but + * remain running then this can lead to a situation where new jobs and tasks are starved of + * resources that are being used by these zombie tasks. + * + * The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor and clean up zombie + * tasks. For backwards-compatibility / backportability this component is disabled by default + * and must be explicitly enabled by setting `spark.task.reaper.enabled=true`. + * + * A TaskReaper is created for a particular task when that task is killed / cancelled. Typically + * a task will have only one TaskReaper, but it's possible for a task to have up to two reapers + * in case kill is called twice with different values for the `interrupt` parameter. + * + * Once created, a TaskReaper will run until its supervised task has finished running. If the + * TaskReaper has not been configured to kill the JVM after a timeout (i.e. if + * `spark.task.reaper.killTimeout < 0`) then this implies that the TaskReaper may run indefinitely + * if the supervised task never exits. + */ + private class TaskReaper( + taskRunner: TaskRunner, + val interruptThread: Boolean) + extends Runnable { + + private[this] val taskId: Long = taskRunner.taskId + + private[this] val killPollingIntervalMs: Long = + conf.getTimeAsMs("spark.task.reaper.pollingInterval", "10s") + + private[this] val killTimeoutMs: Long = conf.getTimeAsMs("spark.task.reaper.killTimeout", "-1") + + private[this] val takeThreadDump: Boolean = + conf.getBoolean("spark.task.reaper.threadDump", true) + + override def run(): Unit = { + val startTimeMs = System.currentTimeMillis() + def elapsedTimeMs = System.currentTimeMillis() - startTimeMs + def timeoutExceeded(): Boolean = killTimeoutMs > 0 && elapsedTimeMs > killTimeoutMs + try { + // Only attempt to kill the task once. If interruptThread = false then a second kill + // attempt would be a no-op and if interruptThread = true then it may not be safe or + // effective to interrupt multiple times: + taskRunner.kill(interruptThread = interruptThread) + // Monitor the killed task until it exits. The synchronization logic here is complicated + // because we don't want to synchronize on the taskRunner while possibly taking a thread + // dump, but we also need to be careful to avoid races between checking whether the task + // has finished and wait()ing for it to finish. + var finished: Boolean = false + while (!finished && !timeoutExceeded()) { + taskRunner.synchronized { + // We need to synchronize on the TaskRunner while checking whether the task has + // finished in order to avoid a race where the task is marked as finished right after + // we check and before we call wait(). + if (taskRunner.isFinished) { + finished = true + } else { + taskRunner.wait(killPollingIntervalMs) + } + } + if (taskRunner.isFinished) { + finished = true + } else { + logWarning(s"Killed task $taskId is still running after $elapsedTimeMs ms") + if (takeThreadDump) { + try { + Utils.getThreadDumpForThread(taskRunner.getThreadId).foreach { thread => + if (thread.threadName == taskRunner.threadName) { + logWarning(s"Thread dump from task $taskId:\n${thread.stackTrace}") + } + } + } catch { + case NonFatal(e) => + logWarning("Exception thrown while obtaining thread dump: ", e) + } + } + } + } + + if (!taskRunner.isFinished && timeoutExceeded()) { + if (isLocal) { + logError(s"Killed task $taskId could not be stopped within $killTimeoutMs ms; " + + "not killing JVM because we are running in local mode.") + } else { + // In non-local-mode, the exception thrown here will bubble up to the uncaught exception + // handler and cause the executor JVM to exit. + throw new SparkException( + s"Killing executor JVM because killed task $taskId could not be stopped within " + + s"$killTimeoutMs ms.") + } + } + } finally { + // Clean up entries in the taskReaperForTask map. + taskReaperForTask.synchronized { + taskReaperForTask.get(taskId).foreach { taskReaperInMap => + if (taskReaperInMap eq this) { + taskReaperForTask.remove(taskId) + } else { + // This must have been a TaskReaper where interruptThread == false where a subsequent + // killTask() call for the same task had interruptThread == true and overwrote the + // map entry. + } + } + } + } + } + } + /** * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes * created by the interpreter to the search path diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 071515134503fc15dbb1a02fb391bf622b7a0560..1319a4ce26f5600b67135e7b2baec86d8572d7dc 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.io._ -import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo} +import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels @@ -2117,28 +2117,46 @@ private[spark] object Utils extends Logging { // We need to filter out null values here because dumpAllThreads() may return null array // elements for threads that are dead / don't exist. val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) - threadInfos.sortBy(_.getThreadId).map { case threadInfo => - val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap - val stackTrace = threadInfo.getStackTrace.map { frame => - monitors.get(frame) match { - case Some(monitor) => - monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" - case None => - frame.toString - } - }.mkString("\n") - - // use a set to dedup re-entrant locks that are held at multiple places - val heldLocks = (threadInfo.getLockedSynchronizers.map(_.lockString) - ++ threadInfo.getLockedMonitors.map(_.lockString) - ).toSet + threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace) + } - ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, threadInfo.getThreadState, - stackTrace, if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId), - Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), heldLocks.toSeq) + def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = { + if (threadId <= 0) { + None + } else { + // The Int.MaxValue here requests the entire untruncated stack trace of the thread: + val threadInfo = + Option(ManagementFactory.getThreadMXBean.getThreadInfo(threadId, Int.MaxValue)) + threadInfo.map(threadInfoToThreadStackTrace) } } + private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = { + val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap + val stackTrace = threadInfo.getStackTrace.map { frame => + monitors.get(frame) match { + case Some(monitor) => + monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" + case None => + frame.toString + } + }.mkString("\n") + + // use a set to dedup re-entrant locks that are held at multiple places + val heldLocks = + (threadInfo.getLockedSynchronizers ++ threadInfo.getLockedMonitors).map(_.lockString).toSet + + ThreadStackTrace( + threadId = threadInfo.getThreadId, + threadName = threadInfo.getThreadName, + threadState = threadInfo.getThreadState, + stackTrace = stackTrace, + blockedByThreadId = + if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId), + blockedByLock = Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), + holdingLocks = heldLocks.toSeq) + } + /** * Convert all spark properties set in the given SparkConf to a sequence of java options. */ diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index a3490fc79e458c0593c77235d7275c1eb53e1164..99150a1430d957dfe14f2a6392aac86775c54f58 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -209,6 +209,83 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft assert(jobB.get() === 100) } + test("task reaper kills JVM if killed tasks keep running for too long") { + val conf = new SparkConf() + .set("spark.task.reaper.enabled", "true") + .set("spark.task.reaper.killTimeout", "5s") + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) + + // Add a listener to release the semaphore once any tasks are launched. + val sem = new Semaphore(0) + sc.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + + // jobA is the one to be cancelled. + val jobA = Future { + sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true) + sc.parallelize(1 to 10000, 2).map { i => + while (true) { } + }.count() + } + + // Block until both tasks of job A have started and cancel job A. + sem.acquire(2) + // Small delay to ensure tasks actually start executing the task body + Thread.sleep(1000) + + sc.clearJobGroup() + val jobB = sc.parallelize(1 to 100, 2).countAsync() + sc.cancelJobGroup("jobA") + val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause + assert(e.getMessage contains "cancel") + + // Once A is cancelled, job B should finish fairly quickly. + assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100) + } + + test("task reaper will not kill JVM if spark.task.killTimeout == -1") { + val conf = new SparkConf() + .set("spark.task.reaper.enabled", "true") + .set("spark.task.reaper.killTimeout", "-1") + .set("spark.task.reaper.PollingInterval", "1s") + .set("spark.deploy.maxExecutorRetries", "1") + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) + + // Add a listener to release the semaphore once any tasks are launched. + val sem = new Semaphore(0) + sc.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + + // jobA is the one to be cancelled. + val jobA = Future { + sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true) + sc.parallelize(1 to 2, 2).map { i => + val startTime = System.currentTimeMillis() + while (System.currentTimeMillis() < startTime + 10000) { } + }.count() + } + + // Block until both tasks of job A have started and cancel job A. + sem.acquire(2) + // Small delay to ensure tasks actually start executing the task body + Thread.sleep(1000) + + sc.clearJobGroup() + val jobB = sc.parallelize(1 to 100, 2).countAsync() + sc.cancelJobGroup("jobA") + val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause + assert(e.getMessage contains "cancel") + + // Once A is cancelled, job B should finish fairly quickly. + assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100) + } + test("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched // twoJobsSharingStageSemaphore: diff --git a/docs/configuration.md b/docs/configuration.md index e33af3abc09d4f1d66d2c8c9581f70cd1e7ec015..9c325b653e52d4018659c988ce9fc3ee7dbfd399 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1366,6 +1366,48 @@ Apart from these, the following properties are also available, and may be useful Should be greater than or equal to 1. Number of allowed retries = this value - 1. </td> </tr> +<tr> + <td><code>spark.task.reaper.enabled</code></td> + <td>false</td> + <td> + Enables monitoring of killed / interrupted tasks. When set to true, any task which is killed + will be monitored by the executor until that task actually finishes executing. See the other + <code>spark.task.reaper.*</code> configurations for details on how to control the exact behavior + of this monitoring</code>. When set to false (the default), task killing will use an older code + path which lacks such monitoring. + </td> +</tr> +<tr> + <td><code>spark.task.reaper.pollingInterval</code></td> + <td>10s</td> + <td> + When <code>spark.task.reaper.enabled = true</code>, this setting controls the frequency at which + executors will poll the status of killed tasks. If a killed task is still running when polled + then a warning will be logged and, by default, a thread-dump of the task will be logged + (this thread dump can be disabled via the <code>spark.task.reaper.threadDump</code> setting, + which is documented below). + </td> +</tr> +<tr> + <td><code>spark.task.reaper.threadDump</code></td> + <td>true</td> + <td> + When <code>spark.task.reaper.enabled = true</code>, this setting controls whether task thread + dumps are logged during periodic polling of killed tasks. Set this to false to disable + collection of thread dumps. + </td> +</tr> +<tr> + <td><code>spark.task.reaper.killTimeout</code></td> + <td>-1</td> + <td> + When <code>spark.task.reaper.enabled = true</code>, this setting specifies a timeout after + which the executor JVM will kill itself if a killed task has not stopped running. The default + value, -1, disables this mechanism and prevents the executor from self-destructing. The purpose + of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering + an executor unusable. + </td> +</tr> </table> #### Dynamic Allocation