diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index dec7fcfa0ddc10d21d97618df9405ca678826a82..e6ddd08e5fa993633e75751dc98229a854e6e1d6 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -34,6 +34,7 @@ import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
@@ -143,8 +144,7 @@ public final class UnsafeExternalSorter {
     taskContext.addOnCompleteCallback(new AbstractFunction0<BoxedUnit>() {
       @Override
       public BoxedUnit apply() {
-        deleteSpillFiles();
-        freeMemory();
+        cleanupResources();
         return null;
       }
     });
@@ -249,7 +249,7 @@ public final class UnsafeExternalSorter {
    *
    * @return the number of bytes freed.
    */
-  public long freeMemory() {
+  private long freeMemory() {
     updatePeakMemoryUsed();
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
@@ -275,44 +275,32 @@ public final class UnsafeExternalSorter {
   /**
    * Deletes any spill files created by this sorter.
    */
-  public void deleteSpillFiles() {
+  private void deleteSpillFiles() {
     for (UnsafeSorterSpillWriter spill : spillWriters) {
       File file = spill.getFile();
       if (file != null && file.exists()) {
         if (!file.delete()) {
           logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
-        };
+        }
       }
     }
   }
 
   /**
-   * Checks whether there is enough space to insert a new record into the sorter.
-   *
-   * @param requiredSpace the required space in the data page, in bytes, including space for storing
-   *                      the record size.
-
-   * @return true if the record can be inserted without requiring more allocations, false otherwise.
+   * Frees this sorter's in-memory data structures and cleans up its spill files.
    */
-  private boolean haveSpaceForRecord(int requiredSpace) {
-    assert(requiredSpace > 0);
-    assert(inMemSorter != null);
-    return (inMemSorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+  public void cleanupResources() {
+    deleteSpillFiles();
+    freeMemory();
   }
 
   /**
-   * Allocates more memory in order to insert an additional record. This will request additional
-   * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
-   * obtained.
-   *
-   * @param requiredSpace the required space in the data page, in bytes, including space for storing
-   *                      the record size.
+   * Checks whether there is enough space to insert an additional record in to the sort pointer
+   * array and grows the array if additional space is required. If the required space cannot be
+   * obtained, then the in-memory data will be spilled to disk.
    */
-  private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+  private void growPointerArrayIfNecessary() throws IOException {
     assert(inMemSorter != null);
-    // TODO: merge these steps to first calculate total memory requirements for this insert,
-    // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
-    // data page.
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
       logger.debug("Attempting to expand sort pointer array");
       final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
@@ -326,7 +314,20 @@ public final class UnsafeExternalSorter {
         shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
       }
     }
+  }
 
+  /**
+   * Allocates more memory in order to insert an additional record. This will request additional
+   * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+   * obtained.
+   *
+   * @param requiredSpace the required space in the data page, in bytes, including space for storing
+   *                      the record size. This must be less than or equal to the page size (records
+   *                      that exceed the page size are handled via a different code path which uses
+   *                      special overflow pages).
+   */
+  private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
+    assert (requiredSpace <= pageSizeBytes);
     if (requiredSpace > freeSpaceInCurrentPage) {
       logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
         freeSpaceInCurrentPage);
@@ -339,9 +340,7 @@ public final class UnsafeExternalSorter {
       } else {
         final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
         if (memoryAcquired < pageSizeBytes) {
-          if (memoryAcquired > 0) {
-            shuffleMemoryManager.release(memoryAcquired);
-          }
+          shuffleMemoryManager.release(memoryAcquired);
           spill();
           final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
           if (memoryAcquiredAfterSpilling != pageSizeBytes) {
@@ -365,26 +364,59 @@ public final class UnsafeExternalSorter {
       long recordBaseOffset,
       int lengthInBytes,
       long prefix) throws IOException {
+
+    growPointerArrayIfNecessary();
     // Need 4 bytes to store the record length.
     final int totalSpaceRequired = lengthInBytes + 4;
-    if (!haveSpaceForRecord(totalSpaceRequired)) {
-      allocateSpaceForRecord(totalSpaceRequired);
+
+    // --- Figure out where to insert the new record ----------------------------------------------
+
+    final MemoryBlock dataPage;
+    long dataPagePosition;
+    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
+    if (useOverflowPage) {
+      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
+      // The record is larger than the page size, so allocate a special overflow page just to hold
+      // that record.
+      final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+      if (memoryGranted != overflowPageSize) {
+        shuffleMemoryManager.release(memoryGranted);
+        spill();
+        final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+        if (memoryGrantedAfterSpill != overflowPageSize) {
+          shuffleMemoryManager.release(memoryGrantedAfterSpill);
+          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
+        }
+      }
+      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+      allocatedPages.add(overflowPage);
+      dataPage = overflowPage;
+      dataPagePosition = overflowPage.getBaseOffset();
+    } else {
+      // The record is small enough to fit in a regular data page, but the current page might not
+      // have enough space to hold it (or no pages have been allocated yet).
+      acquireNewPageIfNecessary(totalSpaceRequired);
+      dataPage = currentPage;
+      dataPagePosition = currentPagePosition;
+      // Update bookkeeping information
+      freeSpaceInCurrentPage -= totalSpaceRequired;
+      currentPagePosition += totalSpaceRequired;
     }
-    assert(inMemSorter != null);
+    final Object dataPageBaseObject = dataPage.getBaseObject();
+
+    // --- Insert the record ----------------------------------------------------------------------
 
     final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
-    final Object dataPageBaseObject = currentPage.getBaseObject();
-    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
-    currentPagePosition += 4;
+      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
+    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
+    dataPagePosition += 4;
     PlatformDependent.copyMemory(
       recordBaseObject,
       recordBaseOffset,
       dataPageBaseObject,
-      currentPagePosition,
+      dataPagePosition,
       lengthInBytes);
-    currentPagePosition += lengthInBytes;
-    freeSpaceInCurrentPage -= totalSpaceRequired;
+    assert(inMemSorter != null);
     inMemSorter.insertRecord(recordAddress, prefix);
   }
 
@@ -399,33 +431,70 @@ public final class UnsafeExternalSorter {
   public void insertKVRecord(
       Object keyBaseObj, long keyOffset, int keyLen,
       Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException {
+
+    growPointerArrayIfNecessary();
     final int totalSpaceRequired = keyLen + valueLen + 4 + 4;
-    if (!haveSpaceForRecord(totalSpaceRequired)) {
-      allocateSpaceForRecord(totalSpaceRequired);
+
+    // --- Figure out where to insert the new record ----------------------------------------------
+
+    final MemoryBlock dataPage;
+    long dataPagePosition;
+    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
+    if (useOverflowPage) {
+      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
+      // The record is larger than the page size, so allocate a special overflow page just to hold
+      // that record.
+      final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+      if (memoryGranted != overflowPageSize) {
+        shuffleMemoryManager.release(memoryGranted);
+        spill();
+        final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+        if (memoryGrantedAfterSpill != overflowPageSize) {
+          shuffleMemoryManager.release(memoryGrantedAfterSpill);
+          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
+        }
+      }
+      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+      allocatedPages.add(overflowPage);
+      dataPage = overflowPage;
+      dataPagePosition = overflowPage.getBaseOffset();
+    } else {
+      // The record is small enough to fit in a regular data page, but the current page might not
+      // have enough space to hold it (or no pages have been allocated yet).
+      acquireNewPageIfNecessary(totalSpaceRequired);
+      dataPage = currentPage;
+      dataPagePosition = currentPagePosition;
+      // Update bookkeeping information
+      freeSpaceInCurrentPage -= totalSpaceRequired;
+      currentPagePosition += totalSpaceRequired;
     }
-    assert(inMemSorter != null);
+    final Object dataPageBaseObject = dataPage.getBaseObject();
+
+    // --- Insert the record ----------------------------------------------------------------------
 
     final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
-    final Object dataPageBaseObject = currentPage.getBaseObject();
-    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, keyLen + valueLen + 4);
-    currentPagePosition += 4;
+      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
+    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4);
+    dataPagePosition += 4;
 
-    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, keyLen);
-    currentPagePosition += 4;
+    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, keyLen);
+    dataPagePosition += 4;
 
     PlatformDependent.copyMemory(
-      keyBaseObj, keyOffset, dataPageBaseObject, currentPagePosition, keyLen);
-    currentPagePosition += keyLen;
+      keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen);
+    dataPagePosition += keyLen;
 
     PlatformDependent.copyMemory(
-      valueBaseObj, valueOffset, dataPageBaseObject, currentPagePosition, valueLen);
-    currentPagePosition += valueLen;
+      valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen);
 
-    freeSpaceInCurrentPage -= totalSpaceRequired;
+    assert(inMemSorter != null);
     inMemSorter.insertRecord(recordAddress, prefix);
   }
 
+  /**
+   * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()`
+   * after consuming this iterator.
+   */
   public UnsafeSorterIterator getSortedIterator() throws IOException {
     assert(inMemSorter != null);
     final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator();
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
index db9e82759090af957f7151fb3f0dffe93252de98..934b7e03050b6446b545c79835160fd0d9b30944 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
@@ -32,8 +32,8 @@ public class PackedRecordPointerSuite {
   public void heap() {
     final TaskMemoryManager memoryManager =
       new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-    final MemoryBlock page0 = memoryManager.allocatePage(100);
-    final MemoryBlock page1 = memoryManager.allocatePage(100);
+    final MemoryBlock page0 = memoryManager.allocatePage(128);
+    final MemoryBlock page1 = memoryManager.allocatePage(128);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
       page1.getBaseOffset() + 42);
     PackedRecordPointer packedPointer = new PackedRecordPointer();
@@ -50,8 +50,8 @@ public class PackedRecordPointerSuite {
   public void offHeap() {
     final TaskMemoryManager memoryManager =
       new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
-    final MemoryBlock page0 = memoryManager.allocatePage(100);
-    final MemoryBlock page1 = memoryManager.allocatePage(100);
+    final MemoryBlock page0 = memoryManager.allocatePage(128);
+    final MemoryBlock page1 = memoryManager.allocatePage(128);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
       page1.getBaseOffset() + 42);
     PackedRecordPointer packedPointer = new PackedRecordPointer();
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index c11949d57a0ea5a06db9aeb76195080f3ce8773e..968185bde78abfa33cc20c79a8599acfb8b47134 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -18,8 +18,10 @@
 package org.apache.spark.util.collection.unsafe.sort;
 
 import java.io.File;
+import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.util.Arrays;
 import java.util.LinkedList;
 import java.util.UUID;
 
@@ -34,6 +36,7 @@ import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.junit.Assert.*;
 import static org.mockito.AdditionalAnswers.returnsSecondArg;
 import static org.mockito.Answers.RETURNS_SMART_NULLS;
@@ -77,12 +80,13 @@ public class UnsafeExternalSorterSuite {
     }
   };
 
+  SparkConf sparkConf;
+  File tempDir;
   ShuffleMemoryManager shuffleMemoryManager;
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
   @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
   @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
 
-  File tempDir;
 
   private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "64m");
 
@@ -96,6 +100,7 @@ public class UnsafeExternalSorterSuite {
   @Before
   public void setUp() {
     MockitoAnnotations.initMocks(this);
+    sparkConf = new SparkConf();
     tempDir = new File(Utils.createTempDir$default$1());
     shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE);
     spillFilesCreated.clear();
@@ -155,14 +160,19 @@ public class UnsafeExternalSorterSuite {
   }
 
   private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
-    final int[] arr = new int[] { value };
+    final int[] arr = new int[]{ value };
     sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
   }
 
-  @Test
-  public void testSortingOnlyByPrefix() throws Exception {
+  private static void insertRecord(
+      UnsafeExternalSorter sorter,
+      int[] record,
+      long prefix) throws IOException {
+    sorter.insertRecord(record, PlatformDependent.INT_ARRAY_OFFSET, record.length * 4, prefix);
+  }
 
-    final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
+  private UnsafeExternalSorter newSorter() throws IOException {
+    return UnsafeExternalSorter.create(
       taskMemoryManager,
       shuffleMemoryManager,
       blockManager,
@@ -171,7 +181,11 @@ public class UnsafeExternalSorterSuite {
       prefixComparator,
       /* initialSize */ 1024,
       pageSizeBytes);
+  }
 
+  @Test
+  public void testSortingOnlyByPrefix() throws Exception {
+    final UnsafeExternalSorter sorter = newSorter();
     insertNumber(sorter, 5);
     insertNumber(sorter, 1);
     insertNumber(sorter, 3);
@@ -186,26 +200,16 @@ public class UnsafeExternalSorterSuite {
       iter.loadNext();
       assertEquals(i, iter.getKeyPrefix());
       assertEquals(4, iter.getRecordLength());
-      // TODO: read rest of value.
+      assertEquals(i, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset()));
     }
 
-    sorter.freeMemory();
+    sorter.cleanupResources();
     assertSpillFilesWereCleanedUp();
   }
 
   @Test
   public void testSortingEmptyArrays() throws Exception {
-
-    final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
-      taskMemoryManager,
-      shuffleMemoryManager,
-      blockManager,
-      taskContext,
-      recordComparator,
-      prefixComparator,
-      /* initialSize */ 1024,
-      pageSizeBytes);
-
+    final UnsafeExternalSorter sorter = newSorter();
     sorter.insertRecord(null, 0, 0, 0);
     sorter.insertRecord(null, 0, 0, 0);
     sorter.spill();
@@ -222,28 +226,89 @@ public class UnsafeExternalSorterSuite {
       assertEquals(0, iter.getRecordLength());
     }
 
-    sorter.freeMemory();
+    sorter.cleanupResources();
     assertSpillFilesWereCleanedUp();
   }
 
   @Test
-  public void testFillingPage() throws Exception {
+  public void spillingOccursInResponseToMemoryPressure() throws Exception {
+    shuffleMemoryManager = new ShuffleMemoryManager(pageSizeBytes * 2);
+    final UnsafeExternalSorter sorter = newSorter();
+    final int numRecords = 100000;
+    for (int i = 0; i <= numRecords; i++) {
+      insertNumber(sorter, numRecords - i);
+    }
+    // Ensure that spill files were created
+    assertThat(tempDir.listFiles().length, greaterThanOrEqualTo(1));
+    // Read back the sorted data:
+    UnsafeSorterIterator iter = sorter.getSortedIterator();
 
-    final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
-      taskMemoryManager,
-      shuffleMemoryManager,
-      blockManager,
-      taskContext,
-      recordComparator,
-      prefixComparator,
-      /* initialSize */ 1024,
-      pageSizeBytes);
+    int i = 0;
+    while (iter.hasNext()) {
+      iter.loadNext();
+      assertEquals(i, iter.getKeyPrefix());
+      assertEquals(4, iter.getRecordLength());
+      assertEquals(i, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset()));
+      i++;
+    }
+    sorter.cleanupResources();
+    assertSpillFilesWereCleanedUp();
+  }
 
+  @Test
+  public void testFillingPage() throws Exception {
+    final UnsafeExternalSorter sorter = newSorter();
     byte[] record = new byte[16];
     while (sorter.getNumberOfAllocatedPages() < 2) {
       sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 0);
     }
-    sorter.freeMemory();
+    sorter.cleanupResources();
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void sortingRecordsThatExceedPageSize() throws Exception {
+    final UnsafeExternalSorter sorter = newSorter();
+    final int[] largeRecord = new int[(int) pageSizeBytes + 16];
+    Arrays.fill(largeRecord, 456);
+    final int[] smallRecord = new int[100];
+    Arrays.fill(smallRecord, 123);
+
+    insertRecord(sorter, largeRecord, 456);
+    sorter.spill();
+    insertRecord(sorter, smallRecord, 123);
+    sorter.spill();
+    insertRecord(sorter, smallRecord, 123);
+    insertRecord(sorter, largeRecord, 456);
+
+    UnsafeSorterIterator iter = sorter.getSortedIterator();
+    // Small record
+    assertTrue(iter.hasNext());
+    iter.loadNext();
+    assertEquals(123, iter.getKeyPrefix());
+    assertEquals(smallRecord.length * 4, iter.getRecordLength());
+    assertEquals(123, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset()));
+    // Small record
+    assertTrue(iter.hasNext());
+    iter.loadNext();
+    assertEquals(123, iter.getKeyPrefix());
+    assertEquals(smallRecord.length * 4, iter.getRecordLength());
+    assertEquals(123, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset()));
+    // Large record
+    assertTrue(iter.hasNext());
+    iter.loadNext();
+    assertEquals(456, iter.getKeyPrefix());
+    assertEquals(largeRecord.length * 4, iter.getRecordLength());
+    assertEquals(456, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset()));
+    // Large record
+    assertTrue(iter.hasNext());
+    iter.loadNext();
+    assertEquals(456, iter.getKeyPrefix());
+    assertEquals(largeRecord.length * 4, iter.getRecordLength());
+    assertEquals(456, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset()));
+
+    assertFalse(iter.hasNext());
+    sorter.cleanupResources();
     assertSpillFilesWereCleanedUp();
   }
 
@@ -289,8 +354,10 @@ public class UnsafeExternalSorterSuite {
       newPeakMemory = sorter.getPeakMemoryUsedBytes();
       assertEquals(previousPeakMemory, newPeakMemory);
     } finally {
-      sorter.freeMemory();
+      sorter.cleanupResources();
+      assertSpillFilesWereCleanedUp();
     }
   }
 
 }
+
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 193906d24790eccf90e4bb5c0eda2306af93ce70..a5ae2b973652787abf5210475837e9b229d6904f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -114,7 +114,7 @@ final class UnsafeExternalRowSorter {
   }
 
   private void cleanupResources() {
-    sorter.freeMemory();
+    sorter.cleanupResources();
   }
 
   @VisibleForTesting
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 312ec8ea0dd9dc7681d9b06b2f1991c168dfe037..86a563df992d0e7aa5ff07b8d0708813432d1003 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -134,6 +134,10 @@ public final class UnsafeKVExternalSorter {
       value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
   }
 
+  /**
+   * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()`
+   * after consuming this iterator.
+   */
   public KVSorterIterator sortedIterator() throws IOException {
     try {
       final UnsafeSorterIterator underlying = sorter.getSortedIterator();
@@ -158,8 +162,11 @@ public final class UnsafeKVExternalSorter {
     sorter.closeCurrentPage();
   }
 
-  private void cleanupResources() {
-    sorter.freeMemory();
+  /**
+   * Frees this sorter's in-memory data structures and cleans up its spill files.
+   */
+  public void cleanupResources() {
+    sorter.cleanupResources();
   }
 
   private static final class KVComparator extends RecordComparator {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 601a5a07ad00295ec7f0259e85b5d046ba45d0a7..08156f0e39ce8ad0f8befb1cd592ac61d276fb45 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import scala.util.Random
 
 import org.apache.spark._
-import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, RowOrdering, UnsafeProjection}
 import org.apache.spark.sql.test.TestSQLContext
@@ -46,6 +46,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
     testKVSorter(keySchema, valueSchema, spill = i > 3)
   }
 
+
   /**
    * Create a test case using randomly generated data for the given key and value schema.
    *
@@ -60,96 +61,151 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
    * If spill is set to true, the sorter will spill probabilistically roughly every 100 records.
    */
   private def testKVSorter(keySchema: StructType, valueSchema: StructType, spill: Boolean): Unit = {
+    // Create the data converters
+    val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)
+    val vExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema)
+    val kConverter = UnsafeProjection.create(keySchema)
+    val vConverter = UnsafeProjection.create(valueSchema)
+
+    val keyDataGen = RandomDataGenerator.forType(keySchema, nullable = false).get
+    val valueDataGen = RandomDataGenerator.forType(valueSchema, nullable = false).get
+
+    val inputData = Seq.fill(1024) {
+      val k = kConverter(kExternalConverter.apply(keyDataGen.apply()).asInstanceOf[InternalRow])
+      val v = vConverter(vExternalConverter.apply(valueDataGen.apply()).asInstanceOf[InternalRow])
+      (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy())
+    }
 
     val keySchemaStr = keySchema.map(_.dataType.simpleString).mkString("[", ",", "]")
     val valueSchemaStr = valueSchema.map(_.dataType.simpleString).mkString("[", ",", "]")
 
     test(s"kv sorting key schema $keySchemaStr and value schema $valueSchemaStr") {
-      // Calling this make sure we have block manager and everything else setup.
-      TestSQLContext
-
-      val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
-      val shuffleMemMgr = new TestShuffleMemoryManager
-      TaskContext.setTaskContext(new TaskContextImpl(
-        stageId = 0,
-        partitionId = 0,
-        taskAttemptId = 98456,
-        attemptNumber = 0,
-        taskMemoryManager = taskMemMgr,
-        metricsSystem = null,
-        internalAccumulators = Seq.empty))
-
-      // Create the data converters
-      val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)
-      val vExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema)
-      val kConverter = UnsafeProjection.create(keySchema)
-      val vConverter = UnsafeProjection.create(valueSchema)
-
-      val keyDataGen = RandomDataGenerator.forType(keySchema, nullable = false).get
-      val valueDataGen = RandomDataGenerator.forType(valueSchema, nullable = false).get
-
-      val input = Seq.fill(1024) {
-        val k = kConverter(kExternalConverter.apply(keyDataGen.apply()).asInstanceOf[InternalRow])
-        val v = vConverter(vExternalConverter.apply(valueDataGen.apply()).asInstanceOf[InternalRow])
-        (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy())
-      }
-
-      val sorter = new UnsafeKVExternalSorter(
-        keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, 16 * 1024 * 1024)
-
-      // Insert generated keys and values into the sorter
-      input.foreach { case (k, v) =>
-        sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow])
-        // 1% chance we will spill
-        if (rand.nextDouble() < 0.01 && spill) {
-          shuffleMemMgr.markAsOutOfMemory()
-          sorter.closeCurrentPage()
-        }
-      }
+      testKVSorter(
+        keySchema,
+        valueSchema,
+        inputData,
+        pageSize = 16 * 1024 * 1024,
+        spill
+      )
+    }
+  }
 
-      // Collect the sorted output
-      val out = new scala.collection.mutable.ArrayBuffer[(InternalRow, InternalRow)]
-      val iter = sorter.sortedIterator()
-      while (iter.next()) {
-        out += Tuple2(iter.getKey.copy(), iter.getValue.copy())
+  /**
+   * Create a test case using the given input data for the given key and value schema.
+   *
+   * The approach works as follows:
+   *
+   * - Create input by randomly generating data based on the given schema
+   * - Run [[UnsafeKVExternalSorter]] on the input data
+   * - Collect the output from the sorter, and make sure the keys are sorted in ascending order
+   * - Sort the input by both key and value, and sort the sorter output also by both key and value.
+   *   Compare the sorted input and sorted output together to make sure all the key/values match.
+   *
+   * If spill is set to true, the sorter will spill probabilistically roughly every 100 records.
+   */
+  private def testKVSorter(
+      keySchema: StructType,
+      valueSchema: StructType,
+      inputData: Seq[(InternalRow, InternalRow)],
+      pageSize: Long,
+      spill: Boolean): Unit = {
+    // Calling this make sure we have block manager and everything else setup.
+    TestSQLContext
+
+    val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+    val shuffleMemMgr = new TestShuffleMemoryManager
+    TaskContext.setTaskContext(new TaskContextImpl(
+      stageId = 0,
+      partitionId = 0,
+      taskAttemptId = 98456,
+      attemptNumber = 0,
+      taskMemoryManager = taskMemMgr,
+      metricsSystem = null,
+      internalAccumulators = Seq.empty))
+
+    val sorter = new UnsafeKVExternalSorter(
+      keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, pageSize)
+
+    // Insert the keys and values into the sorter
+    inputData.foreach { case (k, v) =>
+      sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow])
+      // 1% chance we will spill
+      if (rand.nextDouble() < 0.01 && spill) {
+        shuffleMemMgr.markAsOutOfMemory()
+        sorter.closeCurrentPage()
       }
+    }
 
-      val keyOrdering = RowOrdering.forSchema(keySchema.map(_.dataType))
-      val valueOrdering = RowOrdering.forSchema(valueSchema.map(_.dataType))
-      val kvOrdering = new Ordering[(InternalRow, InternalRow)] {
-        override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = {
-          keyOrdering.compare(x._1, y._1) match {
-            case 0 => valueOrdering.compare(x._2, y._2)
-            case cmp => cmp
-          }
+    // Collect the sorted output
+    val out = new scala.collection.mutable.ArrayBuffer[(InternalRow, InternalRow)]
+    val iter = sorter.sortedIterator()
+    while (iter.next()) {
+      out += Tuple2(iter.getKey.copy(), iter.getValue.copy())
+    }
+    sorter.cleanupResources()
+
+    val keyOrdering = RowOrdering.forSchema(keySchema.map(_.dataType))
+    val valueOrdering = RowOrdering.forSchema(valueSchema.map(_.dataType))
+    val kvOrdering = new Ordering[(InternalRow, InternalRow)] {
+      override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = {
+        keyOrdering.compare(x._1, y._1) match {
+          case 0 => valueOrdering.compare(x._2, y._2)
+          case cmp => cmp
         }
       }
+    }
 
-      // Testing to make sure output from the sorter is sorted by key
-      var prevK: InternalRow = null
-      out.zipWithIndex.foreach { case ((k, v), i) =>
-        if (prevK != null) {
-          assert(keyOrdering.compare(prevK, k) <= 0,
-            s"""
-               |key is not in sorted order:
-               |previous key: $prevK
-               |current key : $k
-               """.stripMargin)
-        }
-        prevK = k
+    // Testing to make sure output from the sorter is sorted by key
+    var prevK: InternalRow = null
+    out.zipWithIndex.foreach { case ((k, v), i) =>
+      if (prevK != null) {
+        assert(keyOrdering.compare(prevK, k) <= 0,
+          s"""
+             |key is not in sorted order:
+             |previous key: $prevK
+             |current key : $k
+             """.stripMargin)
       }
+      prevK = k
+    }
 
-      // Testing to make sure the key/value in output matches input
-      assert(out.sorted(kvOrdering) === input.sorted(kvOrdering))
+    // Testing to make sure the key/value in output matches input
+    assert(out.sorted(kvOrdering) === inputData.sorted(kvOrdering))
 
-      // Make sure there is no memory leak
-      val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory
-      if (shuffleMemMgr != null) {
-        val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask()
-        assert(0L === leakedShuffleMemory)
-      }
-      assert(0 === leakedUnsafeMemory)
-      TaskContext.unset()
+    // Make sure there is no memory leak
+    val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory
+    if (shuffleMemMgr != null) {
+      val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask()
+      assert(0L === leakedShuffleMemory)
     }
+    assert(0 === leakedUnsafeMemory)
+    TaskContext.unset()
+  }
+
+  test("kv sorting with records that exceed page size") {
+    val pageSize = 128
+
+    val schema = StructType(StructField("b", BinaryType) :: Nil)
+    val externalConverter = CatalystTypeConverters.createToCatalystConverter(schema)
+    val converter = UnsafeProjection.create(schema)
+
+    val rand = new Random()
+    val inputData = Seq.fill(1024) {
+      val kBytes = new Array[Byte](rand.nextInt(pageSize))
+      val vBytes = new Array[Byte](rand.nextInt(pageSize))
+      rand.nextBytes(kBytes)
+      rand.nextBytes(vBytes)
+      val k = converter(externalConverter.apply(Row(kBytes)).asInstanceOf[InternalRow])
+      val v = converter(externalConverter.apply(Row(vBytes)).asInstanceOf[InternalRow])
+      (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy())
+    }
+
+    testKVSorter(
+      schema,
+      schema,
+      inputData,
+      pageSize,
+      spill = true
+    )
   }
 }
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
index bbe83d36cf36ba8010dcae85e0c5f90898647337..6722301df19d16c3dfe705a2baf603eccff804e3 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
@@ -24,6 +24,9 @@ public class HeapMemoryAllocator implements MemoryAllocator {
 
   @Override
   public MemoryBlock allocate(long size) throws OutOfMemoryError {
+    if (size % 8 != 0) {
+      throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
+    }
     long[] array = new long[(int) (size / 8)];
     return MemoryBlock.fromLongArray(array);
   }
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
index 15898771fef25350071721444d0d06e5befa83d1..62f4459696c28ed2782c4ac5de2af28aa365fc89 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
@@ -26,6 +26,9 @@ public class UnsafeMemoryAllocator implements MemoryAllocator {
 
   @Override
   public MemoryBlock allocate(long size) throws OutOfMemoryError {
+    if (size % 8 != 0) {
+      throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
+    }
     long address = PlatformDependent.UNSAFE.allocateMemory(size);
     return new MemoryBlock(null, address, size);
   }