From f07e989c02844151587f9a29fe77ea65facea422 Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@databricks.com>
Date: Tue, 20 Dec 2016 01:19:38 +0100
Subject: [PATCH] [SPARK-18928] Check TaskContext.isInterrupted() in
 FileScanRDD, JDBCRDD & UnsafeSorter

## What changes were proposed in this pull request?

In order to respond to task cancellation, Spark tasks must periodically check `TaskContext.isInterrupted()`, but this check is missing from a few critical read paths used in Spark SQL, including `FileScanRDD`, `JDBCRDD`, and UnsafeSorter-based sorts. As a result, interrupted or cancelled tasks can continue running and become zombies (as also described in #16189).
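
The standard mechanism here is cooperative: each iterator polls the task's interrupt flag and throws `TaskKilledException` once it is set. A minimal sketch of that check, written as a hypothetical standalone helper for illustration only (not part of this patch):

```scala
import org.apache.spark.{TaskContext, TaskKilledException}

// Hypothetical helper, for illustration only: the cooperative-cancellation
// check that the read paths above were missing. InterruptibleIterator performs
// essentially this check on every hasNext() call.
object CancellationCheck {
  def killIfInterrupted(context: TaskContext): Unit = {
    // TaskContext.get() returns null outside of a running task (e.g. in some
    // tests), so guard against a null context before checking the flag.
    if (context != null && context.isInterrupted()) {
      throw new TaskKilledException
    }
  }
}
```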

This patch fixes the problem by adding `TaskContext.isInterrupted()` checks to these paths. Note that I could have simply wrapped a number of iterators in `InterruptibleIterator`, but in some cases this would have imposed a performance penalty or would not have been effective due to certain special uses of iterators in Spark SQL. Instead, I inlined `InterruptibleIterator`-style logic into the existing iterator subclasses.

## How was this patch tested?

Tested manually in `spark-shell` with two different reproductions of non-cancellable tasks: one involving scans of huge files and another involving sort-merge joins that spill to disk. Both causes of zombie tasks are fixed by the changes in this patch.
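
For reference, a sketch of the kind of `spark-shell` session used for the file-scan case (the path, timing, and job-group name below are placeholders, not the exact reproduction):

```scala
// Submit the query from a separate thread so the shell stays free to cancel it.
// The job group is set inside that thread because job-group properties are
// thread-local.
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

Future {
  sc.setJobGroup("cancel-test", "long-running file scan", interruptOnCancel = true)
  spark.read.parquet("/path/to/huge/table").selectExpr("count(1)").collect()
}

Thread.sleep(5000)               // let some scan tasks start
sc.cancelJobGroup("cancel-test")
// Before this patch the scan tasks could keep running as zombies; with the new
// checks they fail promptly with TaskKilledException.
```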

Author: Josh Rosen <joshrosen@databricks.com>

Closes #16340 from JoshRosen/sql-task-interruption.

(cherry picked from commit 5857b9ac2d9808d9b89a5b29620b5052e2beebf5)
Signed-off-by: Herman van Hovell <hvanhovell@databricks.com>
---
 .../collection/unsafe/sort/UnsafeInMemorySorter.java | 11 +++++++++++
 .../unsafe/sort/UnsafeSorterSpillReader.java         | 11 +++++++++++
 .../sql/execution/datasources/FileScanRDD.scala      | 12 ++++++++++--
 .../sql/execution/datasources/jdbc/JDBCRDD.scala     |  5 +++--
 4 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 252a35ec6b..5b42843717 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -22,6 +22,8 @@ import java.util.LinkedList;
 
 import org.apache.avro.reflect.Nullable;
 
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
@@ -253,6 +255,7 @@ public final class UnsafeInMemorySorter {
     private long keyPrefix;
     private int recordLength;
     private long currentPageNumber;
+    private final TaskContext taskContext = TaskContext.get();
 
     private SortedIterator(int numRecords, int offset) {
       this.numRecords = numRecords;
@@ -283,6 +286,14 @@ public final class UnsafeInMemorySorter {
 
     @Override
     public void loadNext() {
+      // Kill the task in case it has been marked as killed. This logic is from
+      // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+      // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+      // `hasNext()` because it's technically possible for the caller to be relying on
+      // `getNumRecords()` instead of `hasNext()` to know when to stop.
+      if (taskContext != null && taskContext.isInterrupted()) {
+        throw new TaskKilledException();
+      }
       // This pointer points to a 4-byte record length, followed by the record's bytes
       final long recordPointer = array.get(offset + position);
       currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index a658e5eb47..b6323c624b 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -23,6 +23,8 @@ import com.google.common.io.ByteStreams;
 import com.google.common.io.Closeables;
 
 import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.io.NioBufferedFileInputStream;
 import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.storage.BlockId;
@@ -51,6 +53,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
   private byte[] arr = new byte[1024 * 1024];
   private Object baseObject = arr;
   private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
+  private final TaskContext taskContext = TaskContext.get();
 
   public UnsafeSorterSpillReader(
       SerializerManager serializerManager,
@@ -94,6 +97,14 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
 
   @Override
   public void loadNext() throws IOException {
+    // Kill the task in case it has been marked as killed. This logic is from
+    // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+    // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+    // `hasNext()` because it's technically possible for the caller to be relying on
+    // `getNumRecords()` instead of `hasNext()` to know when to stop.
+    if (taskContext != null && taskContext.isInterrupted()) {
+      throw new TaskKilledException();
+    }
     recordLength = din.readInt();
     keyPrefix = din.readLong();
     if (recordLength > arr.length) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 69338f7d96..b926b92074 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -21,7 +21,7 @@ import java.io.IOException
 
 import scala.collection.mutable
 
-import org.apache.spark.{Partition => RDDPartition, TaskContext}
+import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{InputFileNameHolder, RDD}
 import org.apache.spark.sql.SparkSession
@@ -99,7 +99,15 @@ class FileScanRDD(
       private[this] var currentFile: PartitionedFile = null
       private[this] var currentIterator: Iterator[Object] = null
 
-      def hasNext: Boolean = (currentIterator != null && currentIterator.hasNext) || nextIterator()
+      def hasNext: Boolean = {
+        // Kill the task in case it has been marked as killed. This logic is from
+        // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+        // to avoid performance overhead.
+        if (context.isInterrupted()) {
+          throw new TaskKilledException
+        }
+        (currentIterator != null && currentIterator.hasNext) || nextIterator()
+      }
       def next(): Object = {
         val nextElement = currentIterator.next()
         // TODO: we should have a better separation of row based and batch based scan, so that we
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index d5b11e7bec..2bdc432541 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -23,7 +23,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -301,6 +301,7 @@ private[jdbc] class JDBCRDD(
     rs = stmt.executeQuery()
     val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
 
-    CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close())
+    CompletionIterator[InternalRow, Iterator[InternalRow]](
+      new InterruptibleIterator(context, rowsIterator), close())
   }
 }
-- 
GitLab