Commit 5857b9ac authored by Josh Rosen, committed by Herman van Hovell

[SPARK-18928] Check TaskContext.isInterrupted() in FileScanRDD, JDBCRDD & UnsafeSorter

## What changes were proposed in this pull request?

In order to respond to task cancellation, Spark tasks must periodically check `TaskContext.isInterrupted()`, but this check is missing on a few critical read paths used in Spark SQL, including `FileScanRDD`, `JDBCRDD`, and UnsafeSorter-based sorts. This can cause interrupted / cancelled tasks to continue running and become zombies (as also described in #16189).

This patch fixes the problem by adding `TaskContext.isInterrupted()` checks to these paths. Note that I could have simply wrapped the relevant iterators in `InterruptibleIterator`, but in some cases this would have carried a performance penalty or would not have been effective due to certain special uses of iterators in Spark SQL. Instead, I inlined `InterruptibleIterator`-style logic into the existing iterator subclasses.
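
For illustration only, here is a minimal sketch of the inlined pattern in a hypothetical iterator subclass (`ExampleRecordIterator` and its `records` parameter are made up for this example; the real changes live in `FileScanRDD`, `JDBCRDD`, and the UnsafeSorter iterators shown in the diff below):

```scala
import org.apache.spark.{TaskContext, TaskKilledException}

// Hypothetical iterator used only to illustrate the inlined InterruptibleIterator-style check.
class ExampleRecordIterator[T](records: Iterator[T]) extends Iterator[T] {
  // TaskContext.get() returns null outside of a running task (e.g. on the driver),
  // hence the null check before consulting isInterrupted().
  private val taskContext = TaskContext.get()

  override def hasNext: Boolean = {
    // Fail fast if the task has been marked as killed, instead of wrapping
    // `records` in a separate InterruptibleIterator object.
    if (taskContext != null && taskContext.isInterrupted()) {
      throw new TaskKilledException
    }
    records.hasNext
  }

  override def next(): T = records.next()
}
```

Whether the check belongs in `hasNext()` (as in `FileScanRDD`) or in `loadNext()` (as in the UnsafeSorter iterators) depends on which method callers actually use to drive the iteration.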

## How was this patch tested?

Tested manually in `spark-shell` with two different reproductions of non-cancellable tasks, one involving scans of huge files and another involving sort-merge joins that spill to disk. Both causes of zombie tasks are fixed by the changes added here.
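
For readers who want to try something similar, the following is a rough, hypothetical `spark-shell` sketch of such a cancellation check, not the author's actual reproduction; the input path and job-group name are placeholders:

```scala
import scala.concurrent.{ExecutionContext, Future}
implicit val ec: ExecutionContext = ExecutionContext.global

// Submit a long-running scan from a background thread, tagging it with a job group
// so it can be cancelled from the shell.
val pending = Future {
  spark.sparkContext.setJobGroup("slow-scan", "task cancellation test")
  // "/tmp/huge-dataset" is a placeholder for any input large enough to scan for minutes.
  spark.read.parquet("/tmp/huge-dataset").count()
}

Thread.sleep(10000)  // let the scan start running tasks
spark.sparkContext.cancelJobGroup("slow-scan")
// Before this patch, the scan tasks could keep running as zombies after cancellation;
// with the isInterrupted() checks in place they should terminate promptly.
```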

Author: Josh Rosen <joshrosen@databricks.com>

Closes #16340 from JoshRosen/sql-task-interruption.
parent 4cb49412
@@ -22,6 +22,8 @@ import java.util.LinkedList;
 
 import org.apache.avro.reflect.Nullable;
 
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
@@ -253,6 +255,7 @@ public final class UnsafeInMemorySorter {
     private long keyPrefix;
     private int recordLength;
     private long currentPageNumber;
+    private final TaskContext taskContext = TaskContext.get();
 
     private SortedIterator(int numRecords, int offset) {
       this.numRecords = numRecords;
@@ -283,6 +286,14 @@ public final class UnsafeInMemorySorter {
 
     @Override
     public void loadNext() {
+      // Kill the task in case it has been marked as killed. This logic is from
+      // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+      // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+      // `hasNext()` because it's technically possible for the caller to be relying on
+      // `getNumRecords()` instead of `hasNext()` to know when to stop.
+      if (taskContext != null && taskContext.isInterrupted()) {
+        throw new TaskKilledException();
+      }
       // This pointer points to a 4-byte record length, followed by the record's bytes
       final long recordPointer = array.get(offset + position);
       currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer);
@@ -23,6 +23,8 @@ import com.google.common.io.ByteStreams;
 import com.google.common.io.Closeables;
 
 import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.io.NioBufferedFileInputStream;
 import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.storage.BlockId;
@@ -51,6 +53,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
   private byte[] arr = new byte[1024 * 1024];
   private Object baseObject = arr;
   private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
+  private final TaskContext taskContext = TaskContext.get();
 
   public UnsafeSorterSpillReader(
       SerializerManager serializerManager,
@@ -94,6 +97,14 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
 
   @Override
   public void loadNext() throws IOException {
+    // Kill the task in case it has been marked as killed. This logic is from
+    // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+    // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+    // `hasNext()` because it's technically possible for the caller to be relying on
+    // `getNumRecords()` instead of `hasNext()` to know when to stop.
+    if (taskContext != null && taskContext.isInterrupted()) {
+      throw new TaskKilledException();
+    }
     recordLength = din.readInt();
     keyPrefix = din.readLong();
     if (recordLength > arr.length) {
@@ -21,7 +21,7 @@ import java.io.IOException
 
 import scala.collection.mutable
 
-import org.apache.spark.{Partition => RDDPartition, TaskContext}
+import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{InputFileBlockHolder, RDD}
 import org.apache.spark.sql.SparkSession
@@ -99,7 +99,15 @@ class FileScanRDD(
       private[this] var currentFile: PartitionedFile = null
       private[this] var currentIterator: Iterator[Object] = null
 
-      def hasNext: Boolean = (currentIterator != null && currentIterator.hasNext) || nextIterator()
+      def hasNext: Boolean = {
+        // Kill the task in case it has been marked as killed. This logic is from
+        // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+        // to avoid performance overhead.
+        if (context.isInterrupted()) {
+          throw new TaskKilledException
+        }
+        (currentIterator != null && currentIterator.hasNext) || nextIterator()
+      }
       def next(): Object = {
         val nextElement = currentIterator.next()
         // TODO: we should have a better separation of row based and batch based scan, so that we
@@ -23,7 +23,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -301,6 +301,7 @@ private[jdbc] class JDBCRDD(
     rs = stmt.executeQuery()
     val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
 
-    CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close())
+    CompletionIterator[InternalRow, Iterator[InternalRow]](
+      new InterruptibleIterator(context, rowsIterator), close())
   }
 }