diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 9a7b2ad06cab68428b6dc8a376f776ae1e8c36e2..2e4031267473781577ea0b5d896242ec41f75f4f 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -468,6 +468,12 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
           }
           allocatedPages.clear();
         }
+
+        // in-memory sorter will not be used after spilling
+        assert(inMemSorter != null);
+        released += inMemSorter.getMemoryUsage();
+        inMemSorter.free();
+        inMemSorter = null;
         return released;
       }
     }
@@ -489,10 +495,6 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
           }
           upstream = nextUpstream;
           nextUpstream = null;
-
-          assert(inMemSorter != null);
-          inMemSorter.free();
-          inMemSorter = null;
         }
         numRecords--;
         upstream.loadNext();
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index a218ad4623f463e6e850fa9d305aca088ca14ecb..dce1f15a2963c4ee96641ab79cfd363ac3b044eb 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -108,6 +108,7 @@ public final class UnsafeInMemorySorter {
    */
   public void free() {
     consumer.freeArray(array);
+    array = null;
   }
 
   public void reset() {
@@ -160,28 +161,22 @@ public final class UnsafeInMemorySorter {
     pos++;
   }
 
-  public static final class SortedIterator extends UnsafeSorterIterator {
+  public final class SortedIterator extends UnsafeSorterIterator {
 
-    private final TaskMemoryManager memoryManager;
-    private final int sortBufferInsertPosition;
-    private final LongArray sortBuffer;
-    private int position = 0;
+    private final int numRecords;
+    private int position;
     private Object baseObject;
     private long baseOffset;
     private long keyPrefix;
     private int recordLength;
 
-    private SortedIterator(
-        TaskMemoryManager memoryManager,
-        int sortBufferInsertPosition,
-        LongArray sortBuffer) {
-      this.memoryManager = memoryManager;
-      this.sortBufferInsertPosition = sortBufferInsertPosition;
-      this.sortBuffer = sortBuffer;
+    private SortedIterator(int numRecords) {
+      this.numRecords = numRecords;
+      this.position = 0;
     }
 
     public SortedIterator clone () {
-      SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer);
+      SortedIterator iter = new SortedIterator(numRecords);
       iter.position = position;
       iter.baseObject = baseObject;
       iter.baseOffset = baseOffset;
@@ -192,21 +187,21 @@ public final class UnsafeInMemorySorter {
 
     @Override
     public boolean hasNext() {
-      return position < sortBufferInsertPosition;
+      return position / 2 < numRecords;
     }
 
     public int numRecordsLeft() {
-      return (sortBufferInsertPosition - position) / 2;
+      return numRecords - position / 2;
     }
 
     @Override
     public void loadNext() {
       // This pointer points to a 4-byte record length, followed by the record's bytes
-      final long recordPointer = sortBuffer.get(position);
+      final long recordPointer = array.get(position);
       baseObject = memoryManager.getPage(recordPointer);
       baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4;  // Skip over record length
       recordLength = Platform.getInt(baseObject, baseOffset - 4);
-      keyPrefix = sortBuffer.get(position + 1);
+      keyPrefix = array.get(position + 1);
       position += 2;
     }
 
@@ -229,6 +224,6 @@ public final class UnsafeInMemorySorter {
    */
   public SortedIterator getSortedIterator() {
     sorter.sort(array, 0, pos / 2, sortComparator);
-    return new SortedIterator(memoryManager, pos, array);
+    return new SortedIterator(pos / 2);
   }
 }