diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 20347433e16b2285f7b2db1277053461dbe1676b..5ac3736ac62aaff8b3c88d6f1b05f0607f8e7ae5 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -227,22 +227,35 @@ public final class BytesToBytesMap {
     private final Iterator<MemoryBlock> dataPagesIterator;
     private final Location loc;
 
-    private MemoryBlock currentPage;
+    private MemoryBlock currentPage = null;
     private int currentRecordNumber = 0;
     private Object pageBaseObject;
     private long offsetInPage;
 
+    // If this iterator destructive or not. When it is true, it frees each page as it moves onto
+    // next one.
+    private boolean destructive = false;
+    private BytesToBytesMap bmap;
+
     private BytesToBytesMapIterator(
-        int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc) {
+        int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc,
+        boolean destructive, BytesToBytesMap bmap) {
       this.numRecords = numRecords;
       this.dataPagesIterator = dataPagesIterator;
       this.loc = loc;
+      this.destructive = destructive;
+      this.bmap = bmap;
       if (dataPagesIterator.hasNext()) {
         advanceToNextPage();
       }
     }
 
     private void advanceToNextPage() {
+      if (destructive && currentPage != null) {
+        dataPagesIterator.remove();
+        this.bmap.taskMemoryManager.freePage(currentPage);
+        this.bmap.shuffleMemoryManager.release(currentPage.size());
+      }
       currentPage = dataPagesIterator.next();
       pageBaseObject = currentPage.getBaseObject();
       offsetInPage = currentPage.getBaseOffset();
@@ -281,7 +294,21 @@ public final class BytesToBytesMap {
    * `lookup()`, the behavior of the returned iterator is undefined.
    */
   public BytesToBytesMapIterator iterator() {
-    return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc);
+    return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, false, this);
+  }
+
+  /**
+   * Returns a destructive iterator for iterating over the entries of this map. It frees each page
+   * as it moves onto next one. Notice: it is illegal to call any method on the map after
+   * `destructiveIterator()` has been called.
+   *
+   * For efficiency, all calls to `next()` will return the same {@link Location} object.
+   *
+   * If any other lookups or operations are performed on this map while iterating over it, including
+   * `lookup()`, the behavior of the returned iterator is undefined.
+   */
+  public BytesToBytesMapIterator destructiveIterator() {
+    return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, true, this);
   }
 
   /**
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 0e23a64fb74bb62c5ec9e8d9d982945738e7aa87..3c5003380162fbdc19bc06a16e6cd3af2e72a5c2 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -183,8 +183,7 @@ public abstract class AbstractBytesToBytesMapSuite {
     }
   }
 
-  @Test
-  public void iteratorTest() throws Exception {
+  private void iteratorTestBase(boolean destructive) throws Exception {
     final int size = 4096;
     BytesToBytesMap map = new BytesToBytesMap(
       taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES);
@@ -216,7 +215,14 @@ public abstract class AbstractBytesToBytesMapSuite {
         }
       }
       final java.util.BitSet valuesSeen = new java.util.BitSet(size);
-      final Iterator<BytesToBytesMap.Location> iter = map.iterator();
+      final Iterator<BytesToBytesMap.Location> iter;
+      if (destructive) {
+        iter = map.destructiveIterator();
+      } else {
+        iter = map.iterator();
+      }
+      int numPages = map.getNumDataPages();
+      int countFreedPages = 0;
       while (iter.hasNext()) {
         final BytesToBytesMap.Location loc = iter.next();
         Assert.assertTrue(loc.isDefined());
@@ -228,11 +234,22 @@ public abstract class AbstractBytesToBytesMapSuite {
         if (keyLength == 0) {
           Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0);
         } else {
-        final long key = PlatformDependent.UNSAFE.getLong(
-          keyAddress.getBaseObject(), keyAddress.getBaseOffset());
+          final long key = PlatformDependent.UNSAFE.getLong(
+            keyAddress.getBaseObject(), keyAddress.getBaseOffset());
           Assert.assertEquals(value, key);
         }
         valuesSeen.set((int) value);
+        if (destructive) {
+          // The iterator moves onto next page and frees previous page
+          if (map.getNumDataPages() < numPages) {
+            numPages = map.getNumDataPages();
+            countFreedPages++;
+          }
+        }
+      }
+      if (destructive) {
+        // Latest page is not freed by iterator but by map itself
+        Assert.assertEquals(countFreedPages, numPages - 1);
       }
       Assert.assertEquals(size, valuesSeen.cardinality());
     } finally {
@@ -240,6 +257,16 @@ public abstract class AbstractBytesToBytesMapSuite {
     }
   }
 
+  @Test
+  public void iteratorTest() throws Exception {
+    iteratorTestBase(false);
+  }
+
+  @Test
+  public void destructiveIteratorTest() throws Exception {
+    iteratorTestBase(true);
+  }
+
   @Test
   public void iteratingOverDataPagesWithWastedSpace() throws Exception {
     final int NUM_ENTRIES = 1000 * 1000;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 02458030b00e9d74cc4535ee4557a143718a19fa..efb33530dac865cbdcd49487dd1a4cff213e512b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -154,14 +154,17 @@ public final class UnsafeFixedWidthAggregationMap {
   }
 
   /**
-   * Returns an iterator over the keys and values in this map.
+   * Returns an iterator over the keys and values in this map. This uses destructive iterator of
+   * BytesToBytesMap. So it is illegal to call any other method on this map after `iterator()` has
+   * been called.
    *
    * For efficiency, each call returns the same object.
    */
   public KVIterator<UnsafeRow, UnsafeRow> iterator() {
     return new KVIterator<UnsafeRow, UnsafeRow>() {
 
-      private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator = map.iterator();
+      private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator =
+        map.destructiveIterator();
       private final UnsafeRow key = new UnsafeRow();
       private final UnsafeRow value = new UnsafeRow();
 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 6c1cf136d9b81006f3f48c8815ab1c50e7f0aaf5..9a65c9d3a404a80587907670cc59e7f381c33e2f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -88,8 +88,11 @@ public final class UnsafeKVExternalSorter {
       final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
         taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements()));
 
-      final int numKeyFields = keySchema.size();
+      // We cannot use the destructive iterator here because we are reusing the existing memory
+      // pages in BytesToBytesMap to hold records during sorting.
+      // The only new memory we are allocating is the pointer/prefix array.
       BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
+      final int numKeyFields = keySchema.size();
       UnsafeRow row = new UnsafeRow();
       while (iter.hasNext()) {
         final BytesToBytesMap.Location loc = iter.next();