diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 81ee7ab58ab5b935927a7e1f5a234721cf5119cd..3c2980e442ab7d989ddc9d74f780130e2fde1673 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -215,8 +215,6 @@ final class ShuffleExternalSorter extends MemoryConsumer {
       }
     }
 
-    inMemSorter.reset();
-
     if (!isLastFile) {  // i.e. this is a spill file
       // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
       // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
@@ -255,6 +253,10 @@ final class ShuffleExternalSorter extends MemoryConsumer {
 
     writeSortedFile(false);
     final long spillSize = freeMemory();
+    inMemSorter.reset();
+    // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
+    // records. Otherwise, if the task is over allocated memory, then without freeing the memory pages,
+    // we might not be able to get memory for the pointer array.
     taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
     return spillSize;
   }
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index fe79ff0e3052b63687444fc2a09e60cbcc5f34d1..76b0e6a304ac21b4cd78d9f4c7ff26464dd3b5c6 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -51,9 +51,12 @@ final class ShuffleInMemorySorter {
    */
   private int pos = 0;
 
+  private int initialSize;
+
   ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) {
     this.consumer = consumer;
     assert (initialSize > 0);
+    this.initialSize = initialSize;
     this.array = consumer.allocateArray(initialSize);
     this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
   }
@@ -70,6 +73,10 @@ final class ShuffleInMemorySorter {
   }
 
   public void reset() {
+    if (consumer != null) {
+      consumer.freeArray(array);
+      this.array = consumer.allocateArray(initialSize);
+    }
     pos = 0;
   }
 
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index ded8f0472b2759bdbf0e745d61211400e803343f..ef79b4908347913d158fb0d75a09b81bb86e91ba 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -200,14 +200,17 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
         spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
       }
       spillWriter.close();
-
-      inMemSorter.reset();
     }
 
     final long spillSize = freeMemory();
     // Note that this is more-or-less going to be a multiple of the page size, so wasted space in
     // pages will currently be counted as memory spilled even though that space isn't actually
     // written to disk. This also counts the space needed to store the sorter's pointer array.
+    inMemSorter.reset();
+    // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
+    // records. Otherwise, if the task is over allocated memory, then without freeing the memory pages,
+    // we might not be able to get memory for the pointer array.
+
     taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
 
     return spillSize;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 145c3a19506475e3254a74d431a32a5a0eca1c55..01eae0e8dc14caff7684c0bda202381890e2a166 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -84,6 +84,8 @@ public final class UnsafeInMemorySorter {
    */
   private int pos = 0;
 
+  private long initialSize;
+
   public UnsafeInMemorySorter(
     final MemoryConsumer consumer,
     final TaskMemoryManager memoryManager,
@@ -102,6 +104,7 @@ public final class UnsafeInMemorySorter {
       LongArray array) {
     this.consumer = consumer;
     this.memoryManager = memoryManager;
+    this.initialSize = array.size();
     if (recordComparator != null) {
       this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
       this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
@@ -123,6 +126,10 @@ public final class UnsafeInMemorySorter {
   }
 
   public void reset() {
+    if (consumer != null) {
+      consumer.freeArray(array);
+      this.array = consumer.allocateArray(initialSize);
+    }
     pos = 0;
   }