diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index f6e0913a7a0b3174d9bcfda6a8e2f7680e1f3ee4..925b60a145886da4ef3f596bee96a57518102688 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -17,10 +17,10 @@
 
 package org.apache.spark.shuffle.unsafe;
 
+import javax.annotation.Nullable;
 import java.io.File;
 import java.io.IOException;
 import java.util.LinkedList;
-import javax.annotation.Nullable;
 
 import scala.Tuple2;
 
@@ -34,8 +34,11 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.serializer.DummySerializerInstance;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
-import org.apache.spark.storage.*;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.DiskBlockObjectWriter;
+import org.apache.spark.storage.TempShuffleBlockId;
 import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
@@ -68,7 +71,7 @@ final class UnsafeShuffleExternalSorter {
   private final int pageSizeBytes;
   @VisibleForTesting
   final int maxRecordSizeBytes;
-  private final TaskMemoryManager memoryManager;
+  private final TaskMemoryManager taskMemoryManager;
   private final ShuffleMemoryManager shuffleMemoryManager;
   private final BlockManager blockManager;
   private final TaskContext taskContext;
@@ -91,7 +94,7 @@ final class UnsafeShuffleExternalSorter {
   private long peakMemoryUsedBytes;
 
   // These variables are reset after spilling:
-  @Nullable private UnsafeShuffleInMemorySorter sorter;
+  @Nullable private UnsafeShuffleInMemorySorter inMemSorter;
   @Nullable private MemoryBlock currentPage = null;
   private long currentPagePosition = -1;
   private long freeSpaceInCurrentPage = 0;
@@ -105,7 +108,7 @@ final class UnsafeShuffleExternalSorter {
       int numPartitions,
       SparkConf conf,
       ShuffleWriteMetrics writeMetrics) throws IOException {
-    this.memoryManager = memoryManager;
+    this.taskMemoryManager = memoryManager;
     this.shuffleMemoryManager = shuffleMemoryManager;
     this.blockManager = blockManager;
     this.taskContext = taskContext;
@@ -133,7 +136,7 @@ final class UnsafeShuffleExternalSorter {
       throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
     }
 
-    this.sorter = new UnsafeShuffleInMemorySorter(initialSize);
+    this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize);
   }
 
   /**
@@ -160,7 +163,7 @@ final class UnsafeShuffleExternalSorter {
 
     // This call performs the actual sort.
     final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
-      sorter.getSortedIterator();
+      inMemSorter.getSortedIterator();
 
     // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
     // after SPARK-5581 is fixed.
@@ -206,8 +209,8 @@ final class UnsafeShuffleExternalSorter {
       }
 
       final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
-      final Object recordPage = memoryManager.getPage(recordPointer);
-      final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer);
+      final Object recordPage = taskMemoryManager.getPage(recordPointer);
+      final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer);
       int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage);
       long recordReadPosition = recordOffsetInPage + 4; // skip over record length
       while (dataRemaining > 0) {
@@ -269,9 +272,9 @@ final class UnsafeShuffleExternalSorter {
       spills.size() > 1 ? " times" : " time");
 
     writeSortedFile(false);
-    final long sorterMemoryUsage = sorter.getMemoryUsage();
-    sorter = null;
-    shuffleMemoryManager.release(sorterMemoryUsage);
+    final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
+    inMemSorter = null;
+    shuffleMemoryManager.release(inMemSorterMemoryUsage);
     final long spillSize = freeMemory();
     taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
 
@@ -283,7 +286,7 @@ final class UnsafeShuffleExternalSorter {
     for (MemoryBlock page : allocatedPages) {
       totalPageSize += page.size();
     }
-    return ((sorter == null) ? 0 : sorter.getMemoryUsage()) + totalPageSize;
+    return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
   }
 
   private void updatePeakMemoryUsed() {
@@ -305,7 +308,7 @@ final class UnsafeShuffleExternalSorter {
     updatePeakMemoryUsed();
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
-      memoryManager.freePage(block);
+      taskMemoryManager.freePage(block);
       shuffleMemoryManager.release(block.size());
       memoryFreed += block.size();
     }
@@ -319,54 +322,53 @@ final class UnsafeShuffleExternalSorter {
   /**
    * Force all memory and spill files to be deleted; called by shuffle error-handling code.
    */
-  public void cleanupAfterError() {
+  public void cleanupResources() {
     freeMemory();
     for (SpillInfo spill : spills) {
       if (spill.file.exists() && !spill.file.delete()) {
         logger.error("Unable to delete spill file {}", spill.file.getPath());
       }
     }
-    if (sorter != null) {
-      shuffleMemoryManager.release(sorter.getMemoryUsage());
-      sorter = null;
+    if (inMemSorter != null) {
+      shuffleMemoryManager.release(inMemSorter.getMemoryUsage());
+      inMemSorter = null;
     }
   }
 
   /**
-   * Checks whether there is enough space to insert a new record into the sorter.
-   *
-   * @param requiredSpace the required space in the data page, in bytes, including space for storing
-   *                      the record size.
-
-   * @return true if the record can be inserted without requiring more allocations, false otherwise.
-   */
-  private boolean haveSpaceForRecord(int requiredSpace) {
-    assert (requiredSpace > 0);
-    return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
-  }
-
-  /**
-   * Allocates more memory in order to insert an additional record. This will request additional
-   * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
-   * obtained.
-   *
-   * @param requiredSpace the required space in the data page, in bytes, including space for storing
-   *                      the record size.
+   * Checks whether there is enough space to insert an additional record into the sort pointer
+   * array and grows the array if additional space is required. If the required space cannot be
+   * obtained, then the in-memory data will be spilled to disk.
    */
-  private void allocateSpaceForRecord(int requiredSpace) throws IOException {
-    if (!sorter.hasSpaceForAnotherRecord()) {
+  private void growPointerArrayIfNecessary() throws IOException {
+    assert(inMemSorter != null);
+    if (!inMemSorter.hasSpaceForAnotherRecord()) {
       logger.debug("Attempting to expand sort pointer array");
-      final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+      final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
       final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
       final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
       if (memoryAcquired < memoryToGrowPointerArray) {
         shuffleMemoryManager.release(memoryAcquired);
         spill();
       } else {
-        sorter.expandPointerArray();
+        inMemSorter.expandPointerArray();
         shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
       }
     }
+  }
+
+  /**
+   * Allocates more memory in order to insert an additional record. This will request additional
+   * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+   * obtained.
+   *
+   * @param requiredSpace the required space in the data page, in bytes, including space for storing
+   *                      the record size. This must be less than or equal to the page size (records
+   *                      that exceed the page size are handled via a different code path which uses
+   *                      special overflow pages).
+   */
+  private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
+    growPointerArrayIfNecessary();
     if (requiredSpace > freeSpaceInCurrentPage) {
       logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
         freeSpaceInCurrentPage);
@@ -387,7 +389,7 @@ final class UnsafeShuffleExternalSorter {
             throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
           }
         }
-        currentPage = memoryManager.allocatePage(pageSizeBytes);
+        currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
         currentPagePosition = currentPage.getBaseOffset();
         freeSpaceInCurrentPage = pageSizeBytes;
         allocatedPages.add(currentPage);
@@ -403,27 +405,58 @@ final class UnsafeShuffleExternalSorter {
       long recordBaseOffset,
       int lengthInBytes,
       int partitionId) throws IOException {
+
+    growPointerArrayIfNecessary();
     // Need 4 bytes to store the record length.
     final int totalSpaceRequired = lengthInBytes + 4;
-    if (!haveSpaceForRecord(totalSpaceRequired)) {
-      allocateSpaceForRecord(totalSpaceRequired);
+
+    // --- Figure out where to insert the new record ----------------------------------------------
+
+    final MemoryBlock dataPage;
+    long dataPagePosition;
+    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
+    if (useOverflowPage) {
+      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
+      // The record is larger than the page size, so allocate a special overflow page just to hold
+      // that record.
+      final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+      if (memoryGranted != overflowPageSize) {
+        shuffleMemoryManager.release(memoryGranted);
+        spill();
+        final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+        if (memoryGrantedAfterSpill != overflowPageSize) {
+          shuffleMemoryManager.release(memoryGrantedAfterSpill);
+          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
+        }
+      }
+      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+      allocatedPages.add(overflowPage);
+      dataPage = overflowPage;
+      dataPagePosition = overflowPage.getBaseOffset();
+    } else {
+      // The record is small enough to fit in a regular data page, but the current page might not
+      // have enough space to hold it (or no pages have been allocated yet).
+      acquireNewPageIfNecessary(totalSpaceRequired);
+      dataPage = currentPage;
+      dataPagePosition = currentPagePosition;
+      // Update bookkeeping information
+      freeSpaceInCurrentPage -= totalSpaceRequired;
+      currentPagePosition += totalSpaceRequired;
     }
+    final Object dataPageBaseObject = dataPage.getBaseObject();
 
     final long recordAddress =
-      memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
-    final Object dataPageBaseObject = currentPage.getBaseObject();
-    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
-    currentPagePosition += 4;
-    freeSpaceInCurrentPage -= 4;
+      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
+    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
+    dataPagePosition += 4;
     PlatformDependent.copyMemory(
       recordBaseObject,
       recordBaseOffset,
       dataPageBaseObject,
-      currentPagePosition,
+      dataPagePosition,
       lengthInBytes);
-    currentPagePosition += lengthInBytes;
-    freeSpaceInCurrentPage -= lengthInBytes;
-    sorter.insertRecord(recordAddress, partitionId);
+    assert(inMemSorter != null);
+    inMemSorter.insertRecord(recordAddress, partitionId);
   }
 
   /**
@@ -435,14 +468,14 @@ final class UnsafeShuffleExternalSorter {
    */
   public SpillInfo[] closeAndGetSpills() throws IOException {
     try {
-      if (sorter != null) {
+      if (inMemSorter != null) {
         // Do not count the final file towards the spill count.
         writeSortedFile(true);
         freeMemory();
       }
       return spills.toArray(new SpillInfo[spills.size()]);
     } catch (IOException e) {
-      cleanupAfterError();
+      cleanupResources();
       throw e;
     }
   }
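
The core of this patch is the new insertRecord() path: a record larger than the page size gets a dedicated, word-aligned overflow page, and the memory for it is acquired with a try/spill/retry/fail sequence against the ShuffleMemoryManager. The standalone sketch below models that control flow only; FakeShuffleMemoryManager, the fixed spill() amount, and the size constants are hypothetical stand-ins for illustration, not Spark's actual classes or values.

import java.io.IOException;

// Minimal model of the acquire-spill-retry flow for an oversized record; only the
// control flow mirrors the patch above, everything else is a simplified stand-in.
public class OverflowPageSketch {

  // Tracks bytes held by this "task" against a fixed cap, loosely like ShuffleMemoryManager.
  static class FakeShuffleMemoryManager {
    private final long cap;
    private long used = 0;
    FakeShuffleMemoryManager(long cap) { this.cap = cap; }
    long tryToAcquire(long bytes) {           // grants up to the requested number of bytes
      long granted = Math.min(bytes, cap - used);
      used += granted;
      return granted;
    }
    void release(long bytes) { used -= bytes; }
  }

  static final long PAGE_SIZE_BYTES = 64;     // hypothetical regular data page size
  static final FakeShuffleMemoryManager memoryManager = new FakeShuffleMemoryManager(256);

  // Stand-in for spilling in-memory data to disk; pretend it frees a fixed amount.
  static void spill() {
    System.out.println("spilling to free memory");
    memoryManager.release(128);
  }

  // Round up to a multiple of 8 bytes, the word alignment an overflow page size needs.
  static long roundToWord(long bytes) { return (bytes + 7) / 8 * 8; }

  static void insertRecord(int lengthInBytes) throws IOException {
    final long totalSpaceRequired = lengthInBytes + 4;   // 4 extra bytes store the record length
    if (totalSpaceRequired > PAGE_SIZE_BYTES) {
      // Too big for a regular page: allocate a dedicated overflow page just for this record.
      final long overflowPageSize = roundToWord(totalSpaceRequired);
      long granted = memoryManager.tryToAcquire(overflowPageSize);
      if (granted != overflowPageSize) {
        memoryManager.release(granted);                  // give back the partial grant
        spill();                                         // free memory, then retry once
        long grantedAfterSpill = memoryManager.tryToAcquire(overflowPageSize);
        if (grantedAfterSpill != overflowPageSize) {
          memoryManager.release(grantedAfterSpill);
          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
        }
      }
      System.out.println("allocated a " + overflowPageSize + "-byte overflow page");
    } else {
      System.out.println("record fits in a regular data page");
    }
  }

  public static void main(String[] args) throws IOException {
    memoryManager.tryToAcquire(200);   // simulate memory already held by earlier records
    insertRecord(120);                 // oversized: exercises the spill-and-retry branch
  }
}

Compared with the old allocateSpaceForRecord(), the patch keeps pointer-array growth and data-page acquisition as separate concerns (growPointerArrayIfNecessary() vs. acquireNewPageIfNecessary()), which is what lets the oversized-record branch bypass the regular page pool entirely.
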
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index 6e2eeb37c86f145729dc29bb652438f35c451c46..02084f9122e00fb99ad6e2bd35d373a18cca46fe 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -17,17 +17,17 @@
 
 package org.apache.spark.shuffle.unsafe;
 
+import javax.annotation.Nullable;
 import java.io.*;
 import java.nio.channels.FileChannel;
 import java.util.Iterator;
-import javax.annotation.Nullable;
 
 import scala.Option;
 import scala.Product2;
 import scala.collection.JavaConversions;
+import scala.collection.immutable.Map;
 import scala.reflect.ClassTag;
 import scala.reflect.ClassTag$;
-import scala.collection.immutable.Map;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.io.ByteStreams;
@@ -38,10 +38,10 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.spark.*;
 import org.apache.spark.annotation.Private;
+import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
 import org.apache.spark.io.LZFCompressionCodec;
-import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
@@ -178,7 +178,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     } finally {
       if (sorter != null) {
         try {
-          sorter.cleanupAfterError();
+          sorter.cleanupResources();
         } catch (Exception e) {
           // Only throw this error if we won't be masking another
           // error.
@@ -482,7 +482,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       if (sorter != null) {
         // If sorter is non-null, then this implies that we called stop() in response to an error,
         // so we need to clean up memory and spill files created by the sorter
-        sorter.cleanupAfterError();
+        sorter.cleanupResources();
       }
     }
   }
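
Both call sites above route error handling through the renamed cleanupResources() inside blocks that are careful not to let a cleanup failure mask the original exception (see the "Only throw this error if we won't be masking another error" comment). Below is a compact sketch of that pattern, with a hypothetical Resource type standing in for the sorter; only the error-handling shape mirrors the writer's code.

import java.io.Closeable;
import java.io.IOException;

// Sketch of cleanup that never masks the primary failure.
public class CleanupSketch {

  interface Resource extends Closeable {
    void doWork() throws IOException;
  }

  static void runAndCleanUp(Resource resource) throws IOException {
    boolean success = false;
    try {
      resource.doWork();
      success = true;
    } finally {
      try {
        resource.close();                       // analogous to sorter.cleanupResources()
      } catch (IOException cleanupError) {
        if (success) {
          throw cleanupError;                   // nothing else failed, so surface it
        } else {
          // doWork() already threw; log and swallow so its exception is not masked
          System.err.println("suppressed cleanup error: " + cleanupError.getMessage());
        }
      }
    }
  }

  public static void main(String[] args) {
    try {
      runAndCleanUp(new Resource() {
        @Override public void doWork() throws IOException { throw new IOException("work failed"); }
        @Override public void close() throws IOException { throw new IOException("cleanup failed"); }
      });
    } catch (IOException e) {
      // Prints "work failed": the cleanup error was suppressed rather than masking it.
      System.out.println("caught: " + e.getMessage());
    }
  }
}
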
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index c68354ba49a46e104b50d9dbd0b2c036ead04859..94650be536b5fa7a9fd8518c547b8030f14fffe6 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -475,62 +475,22 @@ public class UnsafeShuffleWriterSuite {
 
   @Test
   public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
-    // Use a custom serializer so that we have exact control over the size of serialized data.
-    final Serializer byteArraySerializer = new Serializer() {
-      @Override
-      public SerializerInstance newInstance() {
-        return new SerializerInstance() {
-          @Override
-          public SerializationStream serializeStream(final OutputStream s) {
-            return new SerializationStream() {
-              @Override
-              public void flush() { }
-
-              @Override
-              public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
-                byte[] bytes = (byte[]) t;
-                try {
-                  s.write(bytes);
-                } catch (IOException e) {
-                  throw new RuntimeException(e);
-                }
-                return this;
-              }
-
-              @Override
-              public void close() { }
-            };
-          }
-          public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) { return null; }
-          public DeserializationStream deserializeStream(InputStream s) { return null; }
-          public <T> T deserialize(ByteBuffer b, ClassLoader l, ClassTag<T> ev1) { return null; }
-          public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) { return null; }
-        };
-      }
-    };
-    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(byteArraySerializer));
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    // Insert a record and force a spill so that there's something to clean up:
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[1], new byte[1]));
-    writer.forceSorterToSpill();
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(new byte[1])));
     // We should be able to write a record that's right _at_ the max record size
     final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()];
     new Random(42).nextBytes(atMaxRecordSize);
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[0], atMaxRecordSize));
-    writer.forceSorterToSpill();
-    // Inserting a record that's larger than the max record size should fail:
+    dataToWrite.add(new Tuple2<Object, Object>(2, ByteBuffer.wrap(atMaxRecordSize)));
+    // Inserting a record that's larger than the max record size should also succeed (it is
+    // handled via an overflow page):
     final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1];
     new Random(42).nextBytes(exceedsMaxRecordSize);
-    Product2<Object, Object> hugeRecord =
-      new Tuple2<Object, Object>(new byte[0], exceedsMaxRecordSize);
-    try {
-      // Here, we write through the public `write()` interface instead of the test-only
-      // `insertRecordIntoSorter` interface:
-      writer.write(Collections.singletonList(hugeRecord).iterator());
-      fail("Expected exception to be thrown");
-    } catch (IOException e) {
-      // Pass
-    }
+    dataToWrite.add(new Tuple2<Object, Object>(3, ByteBuffer.wrap(exceedsMaxRecordSize)));
+    writer.write(dataToWrite.iterator());
+    writer.stop(true);
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
     assertSpillFilesWereCleanedUp();
   }