diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 252a35ec6bdf576ede51ad898ef92dfdc47ab4c4..5b42843717e9037390184620c7737a5dcd626d24 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -22,6 +22,8 @@ import java.util.LinkedList;
 
 import org.apache.avro.reflect.Nullable;
 
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
@@ -253,6 +255,7 @@ public final class UnsafeInMemorySorter {
     private long keyPrefix;
     private int recordLength;
     private long currentPageNumber;
+    private final TaskContext taskContext = TaskContext.get();
 
     private SortedIterator(int numRecords, int offset) {
       this.numRecords = numRecords;
@@ -283,6 +286,14 @@ public final class UnsafeInMemorySorter {
 
     @Override
     public void loadNext() {
+      // Kill the task in case it has been marked as killed. This logic is from
+      // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+      // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+      // `hasNext()` because it's technically possible for the caller to be relying on
+      // `getNumRecords()` instead of `hasNext()` to know when to stop.
+      if (taskContext != null && taskContext.isInterrupted()) {
+        throw new TaskKilledException();
+      }
       // This pointer points to a 4-byte record length, followed by the record's bytes
       final long recordPointer = array.get(offset + position);
       currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer);
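
For reference, the "logic from InterruptibleIterator" that the check above inlines is a delegating iterator that tests the task's interrupt flag before each element. A simplified Scala sketch of that pattern (an illustration only, not Spark's actual InterruptibleIterator class; the class name here is made up):

    import org.apache.spark.{TaskContext, TaskKilledException}

    // Delegating iterator that checks the task's interrupt flag on every
    // hasNext() call and aborts with TaskKilledException once the task has
    // been marked as killed.
    class KillCheckingIterator[T](context: TaskContext, delegate: Iterator[T])
      extends Iterator[T] {

      override def hasNext: Boolean = {
        if (context.isInterrupted()) {
          throw new TaskKilledException
        }
        delegate.hasNext
      }

      override def next(): T = delegate.next()
    }

The patch inlines this check directly into loadNext() so the hot sort path avoids the extra wrapper indirection per record, which is the "performance overhead" the comment refers to.
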
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index a658e5eb47b781c47b5b93da3f0fd8da0d982f5d..b6323c624b7b96637679984231c6d11cac523353 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -23,6 +23,8 @@ import com.google.common.io.ByteStreams;
 import com.google.common.io.Closeables;
 
 import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.io.NioBufferedFileInputStream;
 import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.storage.BlockId;
@@ -51,6 +53,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
   private byte[] arr = new byte[1024 * 1024];
   private Object baseObject = arr;
   private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
+  private final TaskContext taskContext = TaskContext.get();
 
   public UnsafeSorterSpillReader(
       SerializerManager serializerManager,
@@ -94,6 +97,14 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
 
   @Override
   public void loadNext() throws IOException {
+    // Kill the task in case it has been marked as killed. This logic is from
+    // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+    // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+    // `hasNext()` because it's technically possible for the caller to be relying on
+    // `getNumRecords()` instead of `hasNext()` to know when to stop.
+    if (taskContext != null && taskContext.isInterrupted()) {
+      throw new TaskKilledException();
+    }
     recordLength = din.readInt();
     keyPrefix = din.readLong();
     if (recordLength > arr.length) {
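
The comment above also explains why the check sits in loadNext() rather than hasNext(): some consumers drive these iterators by record count. A hypothetical caller loop (the method name and consumption details are illustrative) makes that concrete:

    import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator

    // Hypothetical consumer that iterates by getNumRecords() and never calls
    // hasNext(); only a check inside loadNext() is guaranteed to run for
    // every record it reads.
    def drain(iter: UnsafeSorterIterator): Unit = {
      var i = 0
      val n = iter.getNumRecords()
      while (i < n) {
        iter.loadNext()  // the kill check above fires here even without hasNext()
        // consume the record via iter.getBaseObject(), iter.getBaseOffset()
        // and iter.getRecordLength()
        i += 1
      }
    }
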
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 69338f7d966153bee4513691a63b0161ecf14d15..b926b9207416fe5a2a4380ab17feba3b78996218 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -21,7 +21,7 @@ import java.io.IOException
 
 import scala.collection.mutable
 
-import org.apache.spark.{Partition => RDDPartition, TaskContext}
+import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{InputFileNameHolder, RDD}
 import org.apache.spark.sql.SparkSession
@@ -99,7 +99,15 @@ class FileScanRDD(
       private[this] var currentFile: PartitionedFile = null
       private[this] var currentIterator: Iterator[Object] = null
 
-      def hasNext: Boolean = (currentIterator != null && currentIterator.hasNext) || nextIterator()
+      def hasNext: Boolean = {
+        // Kill the task in case it has been marked as killed. This logic is from
+        // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+        // to avoid performance overhead.
+        if (context.isInterrupted()) {
+          throw new TaskKilledException
+        }
+        (currentIterator != null && currentIterator.hasNext) || nextIterator()
+      }
       def next(): Object = {
         val nextElement = currentIterator.next()
         // TODO: we should have a better separation of row based and batch based scan, so that we
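
A hedged end-to-end sketch of how a kill reaches the hasNext() check above: cancelling a job group that was set with interruptOnCancel = true marks its running tasks as interrupted, and the next hasNext() call on the executors then throws TaskKilledException instead of scanning further. The SparkSession setup, group name, path, and sleep are all illustrative:

    import scala.concurrent.{ExecutionContext, Future}
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[2]").appName("kill-check-demo").getOrCreate()

    // Run a long file scan on another thread; setJobGroup is thread-local, so
    // it is set on the thread that actually triggers the job.
    val pending = Future {
      spark.sparkContext.setJobGroup("demo-scan", "cancellable scan", interruptOnCancel = true)
      spark.read.parquet("/tmp/illustrative/path").count()  // fails with a cancellation error once killed
    }(ExecutionContext.global)

    Thread.sleep(1000)  // give the scan time to start (illustrative timing)
    spark.sparkContext.cancelJobGroup("demo-scan")
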
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index d5b11e7bec0bd2187d2a4cec96e4cd25d78dfe97..2bdc43254133e11586de5010d887e4a46253fd68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -23,7 +23,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -301,6 +301,7 @@ private[jdbc] class JDBCRDD(
     rs = stmt.executeQuery()
     val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
 
-    CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close())
+    CompletionIterator[InternalRow, Iterator[InternalRow]](
+      new InterruptibleIterator(context, rowsIterator), close())
   }
 }
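
In the JDBC case the wrapper form is used rather than inlining; note the composition order: the kill-checking InterruptibleIterator sits closest to the raw result-set rows, and CompletionIterator wraps it so close() still runs once the rows are exhausted. A generic stand-in for that composition (the helper name is hypothetical, and since CompletionIterator is private[spark] a minimal equivalent is sketched inline):

    import org.apache.spark.{InterruptibleIterator, TaskContext}

    // Hypothetical helper mirroring the composition above: kill checks on the
    // inside, end-of-iteration cleanup on the outside.
    def withKillCheckAndCleanup[T](context: TaskContext, rows: Iterator[T])
        (cleanup: => Unit): Iterator[T] = {
      val interruptible = new InterruptibleIterator(context, rows)
      new Iterator[T] {
        private var done = false
        override def hasNext: Boolean = {
          val more = interruptible.hasNext
          if (!more && !done) {
            done = true
            cleanup  // stand-in for CompletionIterator's completion function
          }
          more
        }
        override def next(): T = interruptible.next()
      }
    }
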