diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 5a97f4f11340c65cceafa55905cafe28aa494944..79d74b23ceaef4877dc4cac9f355cd29fc81f7c4 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -443,6 +443,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
         UnsafeInMemorySorter.SortedIterator inMemIterator =
           ((UnsafeInMemorySorter.SortedIterator) upstream).clone();
 
+        // Iterate over the records that have not been returned and spill them.
         final UnsafeSorterSpillWriter spillWriter =
           new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords);
         while (inMemIterator.hasNext()) {
@@ -458,9 +459,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
 
         long released = 0L;
         synchronized (UnsafeExternalSorter.this) {
-          // release the pages except the one that is used
+          // release the pages except the one that is used. There can still be a caller that
+          // is accessing the current record. We free this page in that caller's next loadNext()
+          // call.
           for (MemoryBlock page : allocatedPages) {
-            if (!loaded || page.getBaseObject() != inMemIterator.getBaseObject()) {
+            if (!loaded || page.getBaseObject() != upstream.getBaseObject()) {
               released += page.size();
               freePage(page);
             } else {