Commit f07e989c authored by Josh Rosen, committed by Herman van Hovell

[SPARK-18928] Check TaskContext.isInterrupted() in FileScanRDD, JDBCRDD & UnsafeSorter


## What changes were proposed in this pull request?

In order to respond to task cancellation, Spark tasks must periodically check `TaskContext.isInterrupted()`, but this check is missing on a few critical read paths used in Spark SQL, including `FileScanRDD`, `JDBCRDD`, and UnsafeSorter-based sorts. This can cause interrupted / cancelled tasks to continue running and become zombies (as also described in #16189).

This patch aims to fix this problem by adding `TaskContext.isInterrupted()` checks to these paths. Note that I could have simply wrapped a number of these iterators in `InterruptibleIterator`, but in some cases this would impose a performance penalty or might not be effective due to certain special uses of Iterators in Spark SQL. Instead, I inlined `InterruptibleIterator`-style logic into the existing iterator subclasses.
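
As a rough illustration of the trade-off (this is a sketch, not code from the patch, and `InlinedCheckIterator` is a made-up name), wrapping versus inlining looks roughly like this:

```scala
import org.apache.spark.{InterruptibleIterator, TaskContext, TaskKilledException}

// Wrapping: every call goes through an extra delegating object on the hot path.
def wrapExample[T](context: TaskContext, underlying: Iterator[T]): Iterator[T] =
  new InterruptibleIterator(context, underlying)

// Inlining: the cancellation check lives directly in the iterator subclass, so no
// additional wrapper object or extra level of delegation is added per record.
class InlinedCheckIterator[T](context: TaskContext, underlying: Iterator[T])
  extends Iterator[T] {
  override def hasNext: Boolean = {
    if (context != null && context.isInterrupted()) {
      throw new TaskKilledException
    }
    underlying.hasNext
  }
  override def next(): T = underlying.next()
}
```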

## How was this patch tested?

Tested manually in `spark-shell` with two different reproductions of non-cancellable tasks, one involving scans of huge files and another involving sort-merge joins that spill to disk. Both causes of zombie tasks are fixed by the changes added here.
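
For reference, a reproduction along these lines can be driven from `spark-shell` roughly as follows (a hypothetical sketch: the path and delay are placeholders, not the exact commands used for this patch):

```scala
// Hypothetical reproduction sketch: start a long file scan, cancel all jobs a few
// seconds later, and check whether the task actually stops or lingers as a zombie.
val sc = spark.sparkContext
new Thread(new Runnable {
  override def run(): Unit = {
    Thread.sleep(5000)   // placeholder delay before cancelling
    sc.cancelAllJobs()   // running tasks are marked killed, so isInterrupted() becomes true
  }
}).start()
spark.read.text("/path/to/huge/file").count()  // placeholder path to a very large file
```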

Author: Josh Rosen <joshrosen@databricks.com>

Closes #16340 from JoshRosen/sql-task-interruption.

(cherry picked from commit 5857b9ac)
Signed-off-by: Herman van Hovell <hvanhovell@databricks.com>
parent c1a26b45
@@ -22,6 +22,8 @@ import java.util.LinkedList;
 import org.apache.avro.reflect.Nullable;
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
@@ -253,6 +255,7 @@ public final class UnsafeInMemorySorter {
     private long keyPrefix;
     private int recordLength;
     private long currentPageNumber;
+    private final TaskContext taskContext = TaskContext.get();

     private SortedIterator(int numRecords, int offset) {
       this.numRecords = numRecords;
@@ -283,6 +286,14 @@ public final class UnsafeInMemorySorter {
     @Override
     public void loadNext() {
+      // Kill the task in case it has been marked as killed. This logic is from
+      // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+      // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+      // `hasNext()` because it's technically possible for the caller to be relying on
+      // `getNumRecords()` instead of `hasNext()` to know when to stop.
+      if (taskContext != null && taskContext.isInterrupted()) {
+        throw new TaskKilledException();
+      }
       // This pointer points to a 4-byte record length, followed by the record's bytes
       final long recordPointer = array.get(offset + position);
       currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer);
...
@@ -23,6 +23,8 @@ import com.google.common.io.ByteStreams;
 import com.google.common.io.Closeables;
 import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.io.NioBufferedFileInputStream;
 import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.storage.BlockId;
@@ -51,6 +53,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable {
   private byte[] arr = new byte[1024 * 1024];
   private Object baseObject = arr;
   private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
+  private final TaskContext taskContext = TaskContext.get();

   public UnsafeSorterSpillReader(
       SerializerManager serializerManager,
@@ -94,6 +97,14 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable {
   @Override
   public void loadNext() throws IOException {
+    // Kill the task in case it has been marked as killed. This logic is from
+    // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+    // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+    // `hasNext()` because it's technically possible for the caller to be relying on
+    // `getNumRecords()` instead of `hasNext()` to know when to stop.
+    if (taskContext != null && taskContext.isInterrupted()) {
+      throw new TaskKilledException();
+    }
     recordLength = din.readInt();
     keyPrefix = din.readLong();
     if (recordLength > arr.length) {
...
@@ -21,7 +21,7 @@ import java.io.IOException
 import scala.collection.mutable

-import org.apache.spark.{Partition => RDDPartition, TaskContext}
+import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{InputFileNameHolder, RDD}
 import org.apache.spark.sql.SparkSession
@@ -99,7 +99,15 @@ class FileScanRDD(
      private[this] var currentFile: PartitionedFile = null
      private[this] var currentIterator: Iterator[Object] = null

-     def hasNext: Boolean = (currentIterator != null && currentIterator.hasNext) || nextIterator()
+     def hasNext: Boolean = {
+       // Kill the task in case it has been marked as killed. This logic is from
+       // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+       // to avoid performance overhead.
+       if (context.isInterrupted()) {
+         throw new TaskKilledException
+       }
+       (currentIterator != null && currentIterator.hasNext) || nextIterator()
+     }
      def next(): Object = {
        val nextElement = currentIterator.next()
        // TODO: we should have a better separation of row based and batch based scan, so that we
...
@@ -23,7 +23,7 @@ import scala.util.control.NonFatal
 import org.apache.commons.lang3.StringUtils

-import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -301,6 +301,7 @@ private[jdbc] class JDBCRDD(
     rs = stmt.executeQuery()
     val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
-    CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close())
+    CompletionIterator[InternalRow, Iterator[InternalRow]](
+      new InterruptibleIterator(context, rowsIterator), close())
   }
 }