From f6d06adf05afa9c5386dc2396c94e7a98730289f Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@databricks.com>
Date: Thu, 22 Oct 2015 09:46:30 -0700
Subject: [PATCH] [SPARK-10708] Consolidate sort shuffle implementations

There's a lot of duplication between SortShuffleManager and UnsafeShuffleManager. Given that these now provide the same set of functionality, now that UnsafeShuffleManager supports large records, I think that we should replace SortShuffleManager's serialized shuffle implementation with UnsafeShuffleManager's and should merge the two managers together.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #8829 from JoshRosen/consolidate-sort-shuffle-implementations.
---
 .../sort/BypassMergeSortShuffleWriter.java    | 106 +++++--
 .../{unsafe => sort}/PackedRecordPointer.java |   2 +-
 .../ShuffleExternalSorter.java}               |  28 +-
 .../ShuffleInMemorySorter.java}               |  16 +-
 .../ShuffleSortDataFormat.java}               |   8 +-
 .../shuffle/sort/SortShuffleFileWriter.java   |  53 ----
 .../shuffle/{unsafe => sort}/SpillInfo.java   |   4 +-
 .../{unsafe => sort}/UnsafeShuffleWriter.java |  12 +-
 .../scala/org/apache/spark/SparkEnv.scala     |   2 +-
 .../shuffle/sort/SortShuffleManager.scala     | 175 +++++++++--
 .../shuffle/sort/SortShuffleWriter.scala      |  28 +-
 .../shuffle/unsafe/UnsafeShuffleManager.scala | 202 -------------
 .../spark/util/collection/ChainedBuffer.scala | 146 ----------
 .../util/collection/ExternalSorter.scala      |  35 +--
 .../PartitionedSerializedPairBuffer.scala     | 273 ------------------
 .../PackedRecordPointerSuite.java             |   5 +-
 .../ShuffleInMemorySorterSuite.java}          |  16 +-
 .../UnsafeShuffleWriterSuite.java             |  10 +-
 .../org/apache/spark/SortShuffleSuite.scala   |  65 +++++
 .../spark/scheduler/DAGSchedulerSuite.scala   |   6 +-
 .../BypassMergeSortShuffleWriterSuite.scala   |  64 ++--
 .../SortShuffleManagerSuite.scala}            |  30 +-
 .../shuffle/sort/SortShuffleWriterSuite.scala |  45 ---
 .../shuffle/unsafe/UnsafeShuffleSuite.scala   | 102 -------
 .../util/collection/ChainedBufferSuite.scala  | 144 ---------
 ...PartitionedSerializedPairBufferSuite.scala | 148 ----------
 docs/configuration.md                         |   7 +-
 project/MimaExcludes.scala                    |   9 +-
 .../apache/spark/sql/execution/Exchange.scala |  23 +-
 .../execution/UnsafeRowSerializerSuite.scala  |   9 +-
 30 files changed, 456 insertions(+), 1317 deletions(-)
 rename core/src/main/java/org/apache/spark/shuffle/{unsafe => sort}/PackedRecordPointer.java (98%)
 rename core/src/main/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleExternalSorter.java => sort/ShuffleExternalSorter.java} (95%)
 rename core/src/main/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleInMemorySorter.java => sort/ShuffleInMemorySorter.java} (88%)
 rename core/src/main/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleSortDataFormat.java => sort/ShuffleSortDataFormat.java} (86%)
 delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
 rename core/src/main/java/org/apache/spark/shuffle/{unsafe => sort}/SpillInfo.java (90%)
 rename core/src/main/java/org/apache/spark/shuffle/{unsafe => sort}/UnsafeShuffleWriter.java (98%)
 delete mode 100644 core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
 delete mode 100644 core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
 delete mode 100644 core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
 rename core/src/test/java/org/apache/spark/shuffle/{unsafe => sort}/PackedRecordPointerSuite.java (96%)
 rename core/src/test/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleInMemorySorterSuite.java => sort/ShuffleInMemorySorterSuite.java} (87%)
 rename core/src/test/java/org/apache/spark/shuffle/{unsafe => sort}/UnsafeShuffleWriterSuite.java (98%)
 rename core/src/test/scala/org/apache/spark/shuffle/{unsafe/UnsafeShuffleManagerSuite.scala => sort/SortShuffleManagerSuite.scala} (80%)
 delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
 delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
 delete mode 100644 core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
 delete mode 100644 core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala

diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index f5d80bbcf3..ee82d67993 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -21,21 +21,30 @@ import java.io.File;
 import java.io.FileInputStream;
 import java.io.FileOutputStream;
 import java.io.IOException;
+import javax.annotation.Nullable;
 
+import scala.None$;
+import scala.Option;
 import scala.Product2;
 import scala.Tuple2;
 import scala.collection.Iterator;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.io.Closeables;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.Partitioner;
+import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.*;
 import org.apache.spark.util.Utils;
 
@@ -62,7 +71,7 @@ import org.apache.spark.util.Utils;
  * <p>
  * There have been proposals to completely remove this code path; see SPARK-6026 for details.
  */
-final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
+final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
 
@@ -72,31 +81,52 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
   private final BlockManager blockManager;
   private final Partitioner partitioner;
   private final ShuffleWriteMetrics writeMetrics;
+  private final int shuffleId;
+  private final int mapId;
   private final Serializer serializer;
+  private final IndexShuffleBlockResolver shuffleBlockResolver;
 
   /** Array of file writers, one for each partition */
   private DiskBlockObjectWriter[] partitionWriters;
+  @Nullable private MapStatus mapStatus;
+  private long[] partitionLengths;
+
+  /**
+   * Are we in the process of stopping? Because map tasks can call stop() with success = true
+   * and then call stop() with success = false if they get an exception, we want to make sure
+   * we don't try deleting files, etc twice.
+   */
+  private boolean stopping = false;
 
   public BypassMergeSortShuffleWriter(
-      SparkConf conf,
       BlockManager blockManager,
-      Partitioner partitioner,
-      ShuffleWriteMetrics writeMetrics,
-      Serializer serializer) {
+      IndexShuffleBlockResolver shuffleBlockResolver,
+      BypassMergeSortShuffleHandle<K, V> handle,
+      int mapId,
+      TaskContext taskContext,
+      SparkConf conf) {
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
-    this.numPartitions = partitioner.numPartitions();
     this.blockManager = blockManager;
-    this.partitioner = partitioner;
-    this.writeMetrics = writeMetrics;
-    this.serializer = serializer;
+    final ShuffleDependency<K, V, V> dep = handle.dependency();
+    this.mapId = mapId;
+    this.shuffleId = dep.shuffleId();
+    this.partitioner = dep.partitioner();
+    this.numPartitions = partitioner.numPartitions();
+    this.writeMetrics = new ShuffleWriteMetrics();
+    taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+    this.serializer = Serializer.getSerializer(dep.serializer());
+    this.shuffleBlockResolver = shuffleBlockResolver;
   }
 
   @Override
-  public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
+  public void write(Iterator<Product2<K, V>> records) throws IOException {
     assert (partitionWriters == null);
     if (!records.hasNext()) {
+      partitionLengths = new long[numPartitions];
+      shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
       return;
     }
     final SerializerInstance serInstance = serializer.newInstance();
@@ -124,13 +154,24 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
     for (DiskBlockObjectWriter writer : partitionWriters) {
       writer.commitAndClose();
     }
+
+    partitionLengths =
+      writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
+    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
 
-  @Override
-  public long[] writePartitionedFile(
-      BlockId blockId,
-      TaskContext context,
-      File outputFile) throws IOException {
+  @VisibleForTesting
+  long[] getPartitionLengths() {
+    return partitionLengths;
+  }
+
+  /**
+   * Concatenate all of the per-partition files into a single combined file.
+   *
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
+   */
+  private long[] writePartitionedFile(File outputFile) throws IOException {
     // Track location of the partition starts in the output file
     final long[] lengths = new long[numPartitions];
     if (partitionWriters == null) {
@@ -165,18 +206,33 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
   }
 
   @Override
-  public void stop() throws IOException {
-    if (partitionWriters != null) {
-      try {
-        for (DiskBlockObjectWriter writer : partitionWriters) {
-          // This method explicitly does _not_ throw exceptions:
-          File file = writer.revertPartialWritesAndClose();
-          if (!file.delete()) {
-            logger.error("Error while deleting file {}", file.getAbsolutePath());
+  public Option<MapStatus> stop(boolean success) {
+    if (stopping) {
+      return None$.empty();
+    } else {
+      stopping = true;
+      if (success) {
+        if (mapStatus == null) {
+          throw new IllegalStateException("Cannot call stop(true) without having called write()");
+        }
+        return Option.apply(mapStatus);
+      } else {
+        // The map task failed, so delete our output data.
+        if (partitionWriters != null) {
+          try {
+            for (DiskBlockObjectWriter writer : partitionWriters) {
+              // This method explicitly does _not_ throw exceptions:
+              File file = writer.revertPartialWritesAndClose();
+              if (!file.delete()) {
+                logger.error("Error while deleting file {}", file.getAbsolutePath());
+              }
+            }
+          } finally {
+            partitionWriters = null;
           }
         }
-      } finally {
-        partitionWriters = null;
+        shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
+        return None$.empty();
       }
     }
   }
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
similarity index 98%
rename from core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
rename to core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
index 4ee6a82c04..c11711966f 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
 
 /**
  * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
similarity index 95%
rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index e73ba39468..85fdaa8115 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
 
 import javax.annotation.Nullable;
 import java.io.File;
@@ -48,7 +48,7 @@ import org.apache.spark.util.Utils;
  * <p>
  * Incoming records are appended to data pages. When all records have been inserted (or when the
  * current thread's shuffle memory limit is reached), the in-memory records are sorted according to
- * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then
+ * their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then
  * written to a single output file (or multiple files, if we've spilled). The format of the output
  * files is the same as the format of the final output file written by
  * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
@@ -59,9 +59,9 @@ import org.apache.spark.util.Utils;
  * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
  * specialized merge procedure that avoids extra serialization/deserialization.
  */
-final class UnsafeShuffleExternalSorter {
+final class ShuffleExternalSorter {
 
-  private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
+  private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);
 
   @VisibleForTesting
   static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
@@ -76,6 +76,10 @@ final class UnsafeShuffleExternalSorter {
   private final BlockManager blockManager;
   private final TaskContext taskContext;
   private final ShuffleWriteMetrics writeMetrics;
+  private long numRecordsInsertedSinceLastSpill = 0;
+
+  /** Force this sorter to spill when there are this many elements in memory. For testing only */
+  private final long numElementsForSpillThreshold;
 
   /** The buffer size to use when writing spills using DiskBlockObjectWriter */
   private final int fileBufferSizeBytes;
@@ -94,12 +98,12 @@ final class UnsafeShuffleExternalSorter {
   private long peakMemoryUsedBytes;
 
   // These variables are reset after spilling:
-  @Nullable private UnsafeShuffleInMemorySorter inMemSorter;
+  @Nullable private ShuffleInMemorySorter inMemSorter;
   @Nullable private MemoryBlock currentPage = null;
   private long currentPagePosition = -1;
   private long freeSpaceInCurrentPage = 0;
 
-  public UnsafeShuffleExternalSorter(
+  public ShuffleExternalSorter(
       TaskMemoryManager memoryManager,
       ShuffleMemoryManager shuffleMemoryManager,
       BlockManager blockManager,
@@ -117,6 +121,8 @@ final class UnsafeShuffleExternalSorter {
     this.numPartitions = numPartitions;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+    this.numElementsForSpillThreshold =
+      conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
     this.pageSizeBytes = (int) Math.min(
       PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
     this.maxRecordSizeBytes = pageSizeBytes - 4;
@@ -140,7 +146,8 @@ final class UnsafeShuffleExternalSorter {
       throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
     }
 
-    this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize);
+    this.inMemSorter = new ShuffleInMemorySorter(initialSize);
+    numRecordsInsertedSinceLastSpill = 0;
   }
 
   /**
@@ -166,7 +173,7 @@ final class UnsafeShuffleExternalSorter {
     }
 
     // This call performs the actual sort.
-    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
+    final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
       inMemSorter.getSortedIterator();
 
     // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
@@ -406,6 +413,10 @@ final class UnsafeShuffleExternalSorter {
       int lengthInBytes,
       int partitionId) throws IOException {
 
+    if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) {
+      spill();
+    }
+
     growPointerArrayIfNecessary();
     // Need 4 bytes to store the record length.
     final int totalSpaceRequired = lengthInBytes + 4;
@@ -453,6 +464,7 @@ final class UnsafeShuffleExternalSorter {
       recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
     assert(inMemSorter != null);
     inMemSorter.insertRecord(recordAddress, partitionId);
+    numRecordsInsertedSinceLastSpill += 1;
   }
 
   /**
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
similarity index 88%
rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index 5bab501da9..a8dee6c610 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -15,13 +15,13 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
 
 import java.util.Comparator;
 
 import org.apache.spark.util.collection.Sorter;
 
-final class UnsafeShuffleInMemorySorter {
+final class ShuffleInMemorySorter {
 
   private final Sorter<PackedRecordPointer, long[]> sorter;
   private static final class SortComparator implements Comparator<PackedRecordPointer> {
@@ -44,10 +44,10 @@ final class UnsafeShuffleInMemorySorter {
    */
   private int pointerArrayInsertPosition = 0;
 
-  public UnsafeShuffleInMemorySorter(int initialSize) {
+  public ShuffleInMemorySorter(int initialSize) {
     assert (initialSize > 0);
     this.pointerArray = new long[initialSize];
-    this.sorter = new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
+    this.sorter = new Sorter<PackedRecordPointer, long[]>(ShuffleSortDataFormat.INSTANCE);
   }
 
   public void expandPointerArray() {
@@ -92,14 +92,14 @@ final class UnsafeShuffleInMemorySorter {
   /**
    * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining.
    */
-  public static final class UnsafeShuffleSorterIterator {
+  public static final class ShuffleSorterIterator {
 
     private final long[] pointerArray;
     private final int numRecords;
     final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
     private int position = 0;
 
-    public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) {
+    public ShuffleSorterIterator(int numRecords, long[] pointerArray) {
       this.numRecords = numRecords;
       this.pointerArray = pointerArray;
     }
@@ -117,8 +117,8 @@ final class UnsafeShuffleInMemorySorter {
   /**
    * Return an iterator over record pointers in sorted order.
    */
-  public UnsafeShuffleSorterIterator getSortedIterator() {
+  public ShuffleSorterIterator getSortedIterator() {
     sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
-    return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
+    return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
   }
 }
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
similarity index 86%
rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
index a66d74ee44..8a1e5aec6f 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
@@ -15,15 +15,15 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
 
 import org.apache.spark.util.collection.SortDataFormat;
 
-final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
+final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
 
-  public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
+  public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();
 
-  private UnsafeShuffleSortDataFormat() { }
+  private ShuffleSortDataFormat() { }
 
   @Override
   public PackedRecordPointer getKey(long[] data, int pos) {
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
deleted file mode 100644
index 656ea0401a..0000000000
--- a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.sort;
-
-import java.io.File;
-import java.io.IOException;
-
-import scala.Product2;
-import scala.collection.Iterator;
-
-import org.apache.spark.annotation.Private;
-import org.apache.spark.TaskContext;
-import org.apache.spark.storage.BlockId;
-
-/**
- * Interface for objects that {@link SortShuffleWriter} uses to write its output files.
- */
-@Private
-public interface SortShuffleFileWriter<K, V> {
-
-  void insertAll(Iterator<Product2<K, V>> records) throws IOException;
-
-  /**
-   * Write all the data added into this shuffle sorter into a file in the disk store. This is
-   * called by the SortShuffleWriter and can go through an efficient path of just concatenating
-   * binary files if we decided to avoid merge-sorting.
-   *
-   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
-   * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
-   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
-   */
-  long[] writePartitionedFile(
-      BlockId blockId,
-      TaskContext context,
-      File outputFile) throws IOException;
-
-  void stop() throws IOException;
-}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
similarity index 90%
rename from core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
rename to core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
index 7bac0dc0bb..df9f7b7abe 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
@@ -15,14 +15,14 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
 
 import java.io.File;
 
 import org.apache.spark.storage.TempShuffleBlockId;
 
 /**
- * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}.
+ * Metadata for a block of data written by {@link ShuffleExternalSorter}.
  */
 final class SpillInfo {
   final long[] partitionLengths;
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
similarity index 98%
rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
rename to core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index fdb309e365..e8f050cb2d 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
 
 import javax.annotation.Nullable;
 import java.io.*;
@@ -80,7 +80,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final boolean transferToEnabled;
 
   @Nullable private MapStatus mapStatus;
-  @Nullable private UnsafeShuffleExternalSorter sorter;
+  @Nullable private ShuffleExternalSorter sorter;
   private long peakMemoryUsedBytes = 0;
 
   /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
@@ -104,15 +104,15 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       IndexShuffleBlockResolver shuffleBlockResolver,
       TaskMemoryManager memoryManager,
       ShuffleMemoryManager shuffleMemoryManager,
-      UnsafeShuffleHandle<K, V> handle,
+      SerializedShuffleHandle<K, V> handle,
       int mapId,
       TaskContext taskContext,
       SparkConf sparkConf) throws IOException {
     final int numPartitions = handle.dependency().partitioner().numPartitions();
-    if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) {
+    if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
       throw new IllegalArgumentException(
         "UnsafeShuffleWriter can only be used for shuffles with at most " +
-          UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions");
+          SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + " reduce partitions");
     }
     this.blockManager = blockManager;
     this.shuffleBlockResolver = shuffleBlockResolver;
@@ -195,7 +195,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   private void open() throws IOException {
     assert (sorter == null);
-    sorter = new UnsafeShuffleExternalSorter(
+    sorter = new ShuffleExternalSorter(
       memoryManager,
       shuffleMemoryManager,
       blockManager,
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index c329983451..704158bfc7 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -330,7 +330,7 @@ object SparkEnv extends Logging {
     val shortShuffleMgrNames = Map(
       "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
       "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager",
-      "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager")
+      "tungsten-sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
     val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
     val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
     val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 9df4e55166..1105167d39 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -19,9 +19,53 @@ package org.apache.spark.shuffle.sort
 
 import java.util.concurrent.ConcurrentHashMap
 
-import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency}
+import org.apache.spark._
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle._
 
+/**
+ * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
+ * written to a single map output file. Reducers fetch contiguous regions of this file in order to
+ * read their portion of the map output. In cases where the map output data is too large to fit in
+ * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged
+ * to produce the final output file.
+ *
+ * Sort-based shuffle has two different write paths for producing its map output files:
+ *
+ *  - Serialized sorting: used when all three of the following conditions hold:
+ *    1. The shuffle dependency specifies no aggregation or output ordering.
+ *    2. The shuffle serializer supports relocation of serialized values (this is currently
+ *       supported by KryoSerializer and Spark SQL's custom serializers).
+ *    3. The shuffle produces fewer than 16777216 output partitions.
+ *  - Deserialized sorting: used to handle all other cases.
+ *
+ * -----------------------
+ * Serialized sorting mode
+ * -----------------------
+ *
+ * In the serialized sorting mode, incoming records are serialized as soon as they are passed to the
+ * shuffle writer and are buffered in a serialized form during sorting. This write path implements
+ * several optimizations:
+ *
+ *  - Its sort operates on serialized binary data rather than Java objects, which reduces memory
+ *    consumption and GC overheads. This optimization requires the record serializer to have certain
+ *    properties to allow serialized records to be re-ordered without requiring deserialization.
+ *    See SPARK-4550, where this optimization was first proposed and implemented, for more details.
+ *
+ *  - It uses a specialized cache-efficient sorter ([[ShuffleExternalSorter]]) that sorts
+ *    arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
+ *    record in the sorting array, this fits more of the array into cache.
+ *
+ *  - The spill merging procedure operates on blocks of serialized records that belong to the same
+ *    partition and does not need to deserialize records during the merge.
+ *
+ *  - When the spill compression codec supports concatenation of compressed data, the spill merge
+ *    simply concatenates the serialized and compressed spill partitions to produce the final output
+ *    partition.  This allows efficient data copying methods, like NIO's `transferTo`, to be used
+ *    and avoids the need to allocate decompression or copying buffers during the merge.
+ *
+ * For more details on these optimizations, see SPARK-7081.
+ */
 private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
 
   if (!conf.getBoolean("spark.shuffle.spill", true)) {
@@ -30,8 +74,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
         " Shuffle will continue to spill to disk when necessary.")
   }
 
-  private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf)
-  private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]()
+  /**
+   * A mapping from shuffle ids to the number of mappers producing output for those shuffles.
+   */
+  private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]()
+
+  override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
 
   /**
    * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
@@ -40,7 +88,22 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       shuffleId: Int,
       numMaps: Int,
       dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
-    new BaseShuffleHandle(shuffleId, numMaps, dependency)
+    if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) {
+      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
+      // need map-side aggregation, then write numPartitions files directly and just concatenate
+      // them at the end. This avoids doing serialization and deserialization twice to merge
+      // together the spilled files, which would happen with the normal code path. The downside is
+      // having multiple files open at a time and thus more memory allocated to buffers.
+      new BypassMergeSortShuffleHandle[K, V](
+        shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+    } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
+      // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
+      new SerializedShuffleHandle[K, V](
+        shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+    } else {
+      // Otherwise, buffer map outputs in a deserialized form:
+      new BaseShuffleHandle(shuffleId, numMaps, dependency)
+    }
   }
 
   /**
@@ -52,38 +115,114 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       startPartition: Int,
       endPartition: Int,
       context: TaskContext): ShuffleReader[K, C] = {
-    // We currently use the same block store shuffle fetcher as the hash-based shuffle.
     new BlockStoreShuffleReader(
       handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
   }
 
   /** Get a writer for a given partition. Called on executors by map tasks. */
-  override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
-      : ShuffleWriter[K, V] = {
-    val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]]
-    shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps)
-    new SortShuffleWriter(
-      shuffleBlockResolver, baseShuffleHandle, mapId, context)
+  override def getWriter[K, V](
+      handle: ShuffleHandle,
+      mapId: Int,
+      context: TaskContext): ShuffleWriter[K, V] = {
+    numMapsForShuffle.putIfAbsent(
+      handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
+    val env = SparkEnv.get
+    handle match {
+      case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
+        new UnsafeShuffleWriter(
+          env.blockManager,
+          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+          context.taskMemoryManager(),
+          env.shuffleMemoryManager,
+          unsafeShuffleHandle,
+          mapId,
+          context,
+          env.conf)
+      case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
+        new BypassMergeSortShuffleWriter(
+          env.blockManager,
+          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+          bypassMergeSortHandle,
+          mapId,
+          context,
+          env.conf)
+      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
+        new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
+    }
   }
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
   override def unregisterShuffle(shuffleId: Int): Boolean = {
-    if (shuffleMapNumber.containsKey(shuffleId)) {
-      val numMaps = shuffleMapNumber.remove(shuffleId)
-      (0 until numMaps).map{ mapId =>
+    Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps =>
+      (0 until numMaps).foreach { mapId =>
         shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
       }
     }
     true
   }
 
-  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
-    indexShuffleBlockResolver
-  }
-
   /** Shut down this ShuffleManager. */
   override def stop(): Unit = {
     shuffleBlockResolver.stop()
   }
 }
 
+
+private[spark] object SortShuffleManager extends Logging {
+
+  /**
+   * The maximum number of shuffle output partitions that SortShuffleManager supports when
+   * buffering map outputs in a serialized form. This is an extreme defensive programming measure,
+   * since it's extremely unlikely that a single shuffle produces over 16 million output partitions.
+   * */
+  val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE =
+    PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
+
+  /**
+   * Helper method for determining whether a shuffle should use an optimized serialized shuffle
+   * path or whether it should fall back to the original path that operates on deserialized objects.
+   */
+  def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
+    val shufId = dependency.shuffleId
+    val numPartitions = dependency.partitioner.numPartitions
+    val serializer = Serializer.getSerializer(dependency.serializer)
+    if (!serializer.supportsRelocationOfSerializedObjects) {
+      log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
+        s"${serializer.getClass.getName}, does not support object relocation")
+      false
+    } else if (dependency.aggregator.isDefined) {
+      log.debug(
+        s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined")
+      false
+    } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
+      log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " +
+        s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions")
+      false
+    } else {
+      log.debug(s"Can use serialized shuffle for shuffle $shufId")
+      true
+    }
+  }
+}
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the
+ * serialized shuffle.
+ */
+private[spark] class SerializedShuffleHandle[K, V](
+  shuffleId: Int,
+  numMaps: Int,
+  dependency: ShuffleDependency[K, V, V])
+  extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+}
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the
+ * bypass merge sort shuffle path.
+ */
+private[spark] class BypassMergeSortShuffleHandle[K, V](
+  shuffleId: Int,
+  numMaps: Int,
+  dependency: ShuffleDependency[K, V, V])
+  extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 5865e7640c..bbd9c1ab53 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -20,7 +20,6 @@ package org.apache.spark.shuffle.sort
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
@@ -36,7 +35,7 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   private val blockManager = SparkEnv.get.blockManager
 
-  private var sorter: SortShuffleFileWriter[K, V] = null
+  private var sorter: ExternalSorter[K, V, _] = null
 
   // Are we in the process of stopping? Because map tasks can call stop() with success = true
   // and then call stop() with success = false if they get an exception, we want to make sure
@@ -54,15 +53,6 @@ private[spark] class SortShuffleWriter[K, V, C](
       require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
       new ExternalSorter[K, V, C](
         dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
-    } else if (SortShuffleWriter.shouldBypassMergeSort(
-        SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
-      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
-      // need local aggregation and sorting, write numPartitions files directly and just concatenate
-      // them at the end. This avoids doing serialization and deserialization twice to merge
-      // together the spilled files, which would happen with the normal code path. The downside is
-      // having multiple files open at a time and thus more memory allocated to buffers.
-      new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner,
-        writeMetrics, Serializer.getSerializer(dep.serializer))
     } else {
       // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
       // care whether the keys get sorted in each partition; that will be done on the reduce side
@@ -111,12 +101,14 @@ private[spark] class SortShuffleWriter[K, V, C](
 }
 
 private[spark] object SortShuffleWriter {
-  def shouldBypassMergeSort(
-      conf: SparkConf,
-      numPartitions: Int,
-      aggregator: Option[Aggregator[_, _, _]],
-      keyOrdering: Option[Ordering[_]]): Boolean = {
-    val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-    numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty
+  def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
+    // We cannot bypass sorting if we need to do map-side aggregation.
+    if (dep.mapSideCombine) {
+      require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
+      false
+    } else {
+      val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+      dep.partitioner.numPartitions <= bypassMergeThreshold
+    }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
deleted file mode 100644
index 75f22f642b..0000000000
--- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
+++ /dev/null
@@ -1,202 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe
-
-import java.util.Collections
-import java.util.concurrent.ConcurrentHashMap
-
-import org.apache.spark._
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle._
-import org.apache.spark.shuffle.sort.SortShuffleManager
-
-/**
- * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle.
- */
-private[spark] class UnsafeShuffleHandle[K, V](
-    shuffleId: Int,
-    numMaps: Int,
-    dependency: ShuffleDependency[K, V, V])
-  extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
-}
-
-private[spark] object UnsafeShuffleManager extends Logging {
-
-  /**
-   * The maximum number of shuffle output partitions that UnsafeShuffleManager supports.
-   */
-  val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
-
-  /**
-   * Helper method for determining whether a shuffle should use the optimized unsafe shuffle
-   * path or whether it should fall back to the original sort-based shuffle.
-   */
-  def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = {
-    val shufId = dependency.shuffleId
-    val serializer = Serializer.getSerializer(dependency.serializer)
-    if (!serializer.supportsRelocationOfSerializedObjects) {
-      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " +
-        s"${serializer.getClass.getName}, does not support object relocation")
-      false
-    } else if (dependency.aggregator.isDefined) {
-      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined")
-      false
-    } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) {
-      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " +
-        s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions")
-      false
-    } else {
-      log.debug(s"Can use UnsafeShuffle for shuffle $shufId")
-      true
-    }
-  }
-}
-
-/**
- * A shuffle implementation that uses directly-managed memory to implement several performance
- * optimizations for certain types of shuffles. In cases where the new performance optimizations
- * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those
- * shuffles.
- *
- * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold:
- *
- *  - The shuffle dependency specifies no aggregation or output ordering.
- *  - The shuffle serializer supports relocation of serialized values (this is currently supported
- *    by KryoSerializer and Spark SQL's custom serializers).
- *  - The shuffle produces fewer than 16777216 output partitions.
- *  - No individual record is larger than 128 MB when serialized.
- *
- * In addition, extra spill-merging optimizations are automatically applied when the shuffle
- * compression codec supports concatenation of serialized streams. This is currently supported by
- * Spark's LZF serializer.
- *
- * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager.
- * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
- * written to a single map output file. Reducers fetch contiguous regions of this file in order to
- * read their portion of the map output. In cases where the map output data is too large to fit in
- * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged
- * to produce the final output file.
- *
- * UnsafeShuffleManager optimizes this process in several ways:
- *
- *  - Its sort operates on serialized binary data rather than Java objects, which reduces memory
- *    consumption and GC overheads. This optimization requires the record serializer to have certain
- *    properties to allow serialized records to be re-ordered without requiring deserialization.
- *    See SPARK-4550, where this optimization was first proposed and implemented, for more details.
- *
- *  - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts
- *    arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
- *    record in the sorting array, this fits more of the array into cache.
- *
- *  - The spill merging procedure operates on blocks of serialized records that belong to the same
- *    partition and does not need to deserialize records during the merge.
- *
- *  - When the spill compression codec supports concatenation of compressed data, the spill merge
- *    simply concatenates the serialized and compressed spill partitions to produce the final output
- *    partition.  This allows efficient data copying methods, like NIO's `transferTo`, to be used
- *    and avoids the need to allocate decompression or copying buffers during the merge.
- *
- * For more details on UnsafeShuffleManager's design, see SPARK-7081.
- */
-private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
-
-  if (!conf.getBoolean("spark.shuffle.spill", true)) {
-    logWarning(
-      "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " +
-      "manager; its optimized shuffles will continue to spill to disk when necessary.")
-  }
-
-  private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
-  private[this] val shufflesThatFellBackToSortShuffle =
-    Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
-  private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]()
-
-  /**
-   * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
-   */
-  override def registerShuffle[K, V, C](
-      shuffleId: Int,
-      numMaps: Int,
-      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
-    if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) {
-      new UnsafeShuffleHandle[K, V](
-        shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
-    } else {
-      new BaseShuffleHandle(shuffleId, numMaps, dependency)
-    }
-  }
-
-  /**
-   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
-   * Called on executors by reduce tasks.
-   */
-  override def getReader[K, C](
-      handle: ShuffleHandle,
-      startPartition: Int,
-      endPartition: Int,
-      context: TaskContext): ShuffleReader[K, C] = {
-    sortShuffleManager.getReader(handle, startPartition, endPartition, context)
-  }
-
-  /** Get a writer for a given partition. Called on executors by map tasks. */
-  override def getWriter[K, V](
-      handle: ShuffleHandle,
-      mapId: Int,
-      context: TaskContext): ShuffleWriter[K, V] = {
-    handle match {
-      case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] =>
-        numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps)
-        val env = SparkEnv.get
-        new UnsafeShuffleWriter(
-          env.blockManager,
-          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
-          context.taskMemoryManager(),
-          env.shuffleMemoryManager,
-          unsafeShuffleHandle,
-          mapId,
-          context,
-          env.conf)
-      case other =>
-        shufflesThatFellBackToSortShuffle.add(handle.shuffleId)
-        sortShuffleManager.getWriter(handle, mapId, context)
-    }
-  }
-
-  /** Remove a shuffle's metadata from the ShuffleManager. */
-  override def unregisterShuffle(shuffleId: Int): Boolean = {
-    if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) {
-      sortShuffleManager.unregisterShuffle(shuffleId)
-    } else {
-      Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps =>
-        (0 until numMaps).foreach { mapId =>
-          shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
-        }
-      }
-      true
-    }
-  }
-
-  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
-    sortShuffleManager.shuffleBlockResolver
-  }
-
-  /** Shut down this ShuffleManager. */
-  override def stop(): Unit = {
-    sortShuffleManager.stop()
-  }
-}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
deleted file mode 100644
index ae60f3b0cb..0000000000
--- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.io.OutputStream
-
-import scala.collection.mutable.ArrayBuffer
-
-/**
- * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The
- * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts
- * of memory and needing to copy the full contents. The disadvantage is that the contents don't
- * occupy a contiguous segment of memory.
- */
-private[spark] class ChainedBuffer(chunkSize: Int) {
-
-  private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros(
-    java.lang.Long.highestOneBit(chunkSize))
-  assert((1 << chunkSizeLog2) == chunkSize,
-    s"ChainedBuffer chunk size $chunkSize must be a power of two")
-  private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]()
-  private var _size: Long = 0
-
-  /**
-   * Feed bytes from this buffer into a DiskBlockObjectWriter.
-   *
-   * @param pos Offset in the buffer to read from.
-   * @param os OutputStream to read into.
-   * @param len Number of bytes to read.
-   */
-  def read(pos: Long, os: OutputStream, len: Int): Unit = {
-    if (pos + len > _size) {
-      throw new IndexOutOfBoundsException(
-        s"Read of $len bytes at position $pos would go past size ${_size} of buffer")
-    }
-    var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
-    var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
-    var written: Int = 0
-    while (written < len) {
-      val toRead: Int = math.min(len - written, chunkSize - posInChunk)
-      os.write(chunks(chunkIndex), posInChunk, toRead)
-      written += toRead
-      chunkIndex += 1
-      posInChunk = 0
-    }
-  }
-
-  /**
-   * Read bytes from this buffer into a byte array.
-   *
-   * @param pos Offset in the buffer to read from.
-   * @param bytes Byte array to read into.
-   * @param offs Offset in the byte array to read to.
-   * @param len Number of bytes to read.
-   */
-  def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
-    if (pos + len > _size) {
-      throw new IndexOutOfBoundsException(
-        s"Read of $len bytes at position $pos would go past size of buffer")
-    }
-    var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
-    var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
-    var written: Int = 0
-    while (written < len) {
-      val toRead: Int = math.min(len - written, chunkSize - posInChunk)
-      System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead)
-      written += toRead
-      chunkIndex += 1
-      posInChunk = 0
-    }
-  }
-
-  /**
-   * Write bytes from a byte array into this buffer.
-   *
-   * @param pos Offset in the buffer to write to.
-   * @param bytes Byte array to write from.
-   * @param offs Offset in the byte array to write from.
-   * @param len Number of bytes to write.
-   */
-  def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
-    if (pos > _size) {
-      throw new IndexOutOfBoundsException(
-        s"Write at position $pos starts after end of buffer ${_size}")
-    }
-    // Grow if needed
-    val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt
-    while (endChunkIndex >= chunks.length) {
-      chunks += new Array[Byte](chunkSize)
-    }
-
-    var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
-    var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
-    var written: Int = 0
-    while (written < len) {
-      val toWrite: Int = math.min(len - written, chunkSize - posInChunk)
-      System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite)
-      written += toWrite
-      chunkIndex += 1
-      posInChunk = 0
-    }
-
-    _size = math.max(_size, pos + len)
-  }
-
-  /**
-   * Total size of buffer that can be written to without allocating additional memory.
-   */
-  def capacity: Long = chunks.size.toLong * chunkSize
-
-  /**
-   * Size of the logical buffer.
-   */
-  def size: Long = _size
-}
-
-/**
- * Output stream that writes to a ChainedBuffer.
- */
-private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream {
-  private var pos: Long = 0
-
-  override def write(b: Int): Unit = {
-    throw new UnsupportedOperationException()
-  }
-
-  override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
-    chainedBuffer.write(pos, bytes, offs, len)
-    pos += len
-  }
-}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 749be34d8e..c48c453a90 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -29,7 +29,6 @@ import com.google.common.io.ByteStreams
 import org.apache.spark._
 import org.apache.spark.serializer._
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
 import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
 
 /**
@@ -69,8 +68,8 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
  * At a high level, this class works internally as follows:
  *
  * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if
- *   we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we
- *   don't. Inside these buffers, we sort elements by partition ID and then possibly also by key.
+ *   we want to combine by key, or a PartitionedPairBuffer if we don't.
+ *   Inside these buffers, we sort elements by partition ID and then possibly also by key.
  *   To avoid calling the partitioner multiple times with each key, we store the partition ID
  *   alongside each record.
  *
@@ -93,8 +92,7 @@ private[spark] class ExternalSorter[K, V, C](
     ordering: Option[Ordering[K]] = None,
     serializer: Option[Serializer] = None)
   extends Logging
-  with Spillable[WritablePartitionedPairCollection[K, C]]
-  with SortShuffleFileWriter[K, V] {
+  with Spillable[WritablePartitionedPairCollection[K, C]] {
 
   private val conf = SparkEnv.get.conf
 
@@ -104,13 +102,6 @@ private[spark] class ExternalSorter[K, V, C](
     if (shouldPartition) partitioner.get.getPartition(key) else 0
   }
 
-  // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class.
-  // As a sanity check, make sure that we're not handling a shuffle which should use that path.
-  if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) {
-    throw new IllegalArgumentException("ExternalSorter should not be used to handle "
-      + " a sort that the BypassMergeSortShuffleWriter should handle")
-  }
-
   private val blockManager = SparkEnv.get.blockManager
   private val diskBlockManager = blockManager.diskBlockManager
   private val ser = Serializer.getSerializer(serializer)
@@ -128,23 +119,11 @@ private[spark] class ExternalSorter[K, V, C](
   // grow internal data structures by growing + copying every time the number of objects doubles.
   private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
 
-  private val useSerializedPairBuffer =
-    ordering.isEmpty &&
-      conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
-      ser.supportsRelocationOfSerializedObjects
-  private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
-  private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = {
-    if (useSerializedPairBuffer) {
-      new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance)
-    } else {
-      new PartitionedPairBuffer[K, C]
-    }
-  }
   // Data structures to store in-memory objects before we spill. Depending on whether we have an
   // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
   // store them in an array buffer.
   private var map = new PartitionedAppendOnlyMap[K, C]
-  private var buffer = newBuffer()
+  private var buffer = new PartitionedPairBuffer[K, C]
 
   // Total spilling statistics
   private var _diskBytesSpilled = 0L
@@ -192,7 +171,7 @@ private[spark] class ExternalSorter[K, V, C](
    */
   private[spark] def numSpills: Int = spills.size
 
-  override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
+  def insertAll(records: Iterator[Product2[K, V]]): Unit = {
     // TODO: stop combining if we find that the reduction factor isn't high
     val shouldCombine = aggregator.isDefined
 
@@ -236,7 +215,7 @@ private[spark] class ExternalSorter[K, V, C](
     } else {
       estimatedSize = buffer.estimateSize()
       if (maybeSpill(buffer, estimatedSize)) {
-        buffer = newBuffer()
+        buffer = new PartitionedPairBuffer[K, C]
       }
     }
 
@@ -659,7 +638,7 @@ private[spark] class ExternalSorter[K, V, C](
    * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
    * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
    */
-  override def writePartitionedFile(
+  def writePartitionedFile(
       blockId: BlockId,
       context: TaskContext,
       outputFile: File): Array[Long] = {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
deleted file mode 100644
index 87a786b02d..0000000000
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
+++ /dev/null
@@ -1,273 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.io.InputStream
-import java.nio.IntBuffer
-import java.util.Comparator
-
-import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance}
-import org.apache.spark.storage.DiskBlockObjectWriter
-import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
-
-/**
- * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes
- * its records upon insert and stores them as raw bytes.
- *
- * We use two data-structures to store the contents. The serialized records are stored in a
- * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a
- * metadata buffer that stores pointers into the data buffer as well as the partition ID of each
- * record. Each entry in the metadata buffer takes up a fixed amount of space.
- *
- * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not
- * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can
- * happen without following any pointers, which should minimize cache misses.
- *
- * Currently, only sorting by partition is supported.
- *
- * Each record is laid out inside the the metaBuffer as follows. keyStart, a long, is split across
- * two integers:
- *
- *   +-------------+------------+------------+-------------+
- *   |         keyStart         | keyValLen  | partitionId |
- *   +-------------+------------+------------+-------------+
- *
- * The buffer can support up to `536870911 (2 ^ 29 - 1)` records.
- *
- * @param metaInitialRecords The initial number of entries in the metadata buffer.
- * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records.
- * @param serializerInstance the serializer used for serializing inserted records.
- */
-private[spark] class PartitionedSerializedPairBuffer[K, V](
-    metaInitialRecords: Int,
-    kvBlockSize: Int,
-    serializerInstance: SerializerInstance)
-  extends WritablePartitionedPairCollection[K, V] with SizeTracker {
-
-  if (serializerInstance.isInstanceOf[JavaSerializerInstance]) {
-    throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" +
-      " Java-serialized objects.")
-  }
-
-  require(metaInitialRecords <= MAXIMUM_RECORDS,
-    s"Can't make capacity bigger than ${MAXIMUM_RECORDS} records")
-  private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE)
-
-  private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize)
-  private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer)
-  private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream)
-
-  def insert(partition: Int, key: K, value: V): Unit = {
-    if (metaBuffer.position == metaBuffer.capacity) {
-      growMetaBuffer()
-    }
-
-    val keyStart = kvBuffer.size
-    kvSerializationStream.writeKey[Any](key)
-    kvSerializationStream.writeValue[Any](value)
-    kvSerializationStream.flush()
-    val keyValLen = (kvBuffer.size - keyStart).toInt
-
-    // keyStart, a long, gets split across two ints
-    metaBuffer.put(keyStart.toInt)
-    metaBuffer.put((keyStart >> 32).toInt)
-    metaBuffer.put(keyValLen)
-    metaBuffer.put(partition)
-  }
-
-  /** Double the size of the array because we've reached capacity */
-  private def growMetaBuffer(): Unit = {
-    if (metaBuffer.capacity >= MAXIMUM_META_BUFFER_CAPACITY) {
-      throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_RECORDS} records")
-    }
-    val newCapacity =
-      if (metaBuffer.capacity * 2 < 0 || metaBuffer.capacity * 2 > MAXIMUM_META_BUFFER_CAPACITY) {
-        // Overflow
-        MAXIMUM_META_BUFFER_CAPACITY
-      } else {
-        metaBuffer.capacity * 2
-      }
-    val newMetaBuffer = IntBuffer.allocate(newCapacity)
-    newMetaBuffer.put(metaBuffer.array)
-    metaBuffer = newMetaBuffer
-  }
-
-  /** Iterate through the data in a given order. For this class this is not really destructive. */
-  override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
-    : Iterator[((Int, K), V)] = {
-    sort(keyComparator)
-    val is = orderedInputStream
-    val deserStream = serializerInstance.deserializeStream(is)
-    new Iterator[((Int, K), V)] {
-      var metaBufferPos = 0
-      def hasNext: Boolean = metaBufferPos < metaBuffer.position
-      def next(): ((Int, K), V) = {
-        val key = deserStream.readKey[Any]().asInstanceOf[K]
-        val value = deserStream.readValue[Any]().asInstanceOf[V]
-        val partition = metaBuffer.get(metaBufferPos + PARTITION)
-        metaBufferPos += RECORD_SIZE
-        ((partition, key), value)
-      }
-    }
-  }
-
-  override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity
-
-  override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
-    : WritablePartitionedIterator = {
-    sort(keyComparator)
-    new WritablePartitionedIterator {
-      // current position in the meta buffer in ints
-      var pos = 0
-
-      def writeNext(writer: DiskBlockObjectWriter): Unit = {
-        val keyStart = getKeyStartPos(metaBuffer, pos)
-        val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN)
-        pos += RECORD_SIZE
-        kvBuffer.read(keyStart, writer, keyValLen)
-        writer.recordWritten()
-      }
-      def nextPartition(): Int = metaBuffer.get(pos + PARTITION)
-      def hasNext(): Boolean = pos < metaBuffer.position
-    }
-  }
-
-  // Visible for testing
-  def orderedInputStream: OrderedInputStream = {
-    new OrderedInputStream(metaBuffer, kvBuffer)
-  }
-
-  private def sort(keyComparator: Option[Comparator[K]]): Unit = {
-    val comparator = if (keyComparator.isEmpty) {
-      new Comparator[Int]() {
-        def compare(partition1: Int, partition2: Int): Int = {
-          partition1 - partition2
-        }
-      }
-    } else {
-      throw new UnsupportedOperationException()
-    }
-
-    val sorter = new Sorter(new SerializedSortDataFormat)
-    sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator)
-  }
-}
-
-private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer)
-    extends InputStream {
-
-  import PartitionedSerializedPairBuffer._
-
-  private var metaBufferPos = 0
-  private var kvBufferPos =
-    if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0
-
-  override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length)
-
-  override def read(bytes: Array[Byte], offs: Int, len: Int): Int = {
-    if (metaBufferPos >= metaBuffer.position) {
-      return -1
-    }
-    val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) -
-      (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt
-    val toRead = math.min(bytesRemainingInRecord, len)
-    kvBuffer.read(kvBufferPos, bytes, offs, toRead)
-    if (toRead == bytesRemainingInRecord) {
-      metaBufferPos += RECORD_SIZE
-      if (metaBufferPos < metaBuffer.position) {
-        kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos)
-      }
-    } else {
-      kvBufferPos += toRead
-    }
-    toRead
-  }
-
-  override def read(): Int = {
-    throw new UnsupportedOperationException()
-  }
-}
-
-private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] {
-
-  private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE)
-
-  /** Return the sort key for the element at the given index. */
-  override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = {
-    metaBuffer.get(pos * RECORD_SIZE + PARTITION)
-  }
-
-  /** Swap two elements. */
-  override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = {
-    val iOff = pos0 * RECORD_SIZE
-    val jOff = pos1 * RECORD_SIZE
-    System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE)
-    System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE)
-    System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE)
-  }
-
-  /** Copy a single element from src(srcPos) to dst(dstPos). */
-  override def copyElement(
-      src: IntBuffer,
-      srcPos: Int,
-      dst: IntBuffer,
-      dstPos: Int): Unit = {
-    val srcOff = srcPos * RECORD_SIZE
-    val dstOff = dstPos * RECORD_SIZE
-    System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE)
-  }
-
-  /**
-   * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos.
-   * Overlapping ranges are allowed.
-   */
-  override def copyRange(
-      src: IntBuffer,
-      srcPos: Int,
-      dst: IntBuffer,
-      dstPos: Int,
-      length: Int): Unit = {
-    val srcOff = srcPos * RECORD_SIZE
-    val dstOff = dstPos * RECORD_SIZE
-    System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length)
-  }
-
-  /**
-   * Allocates a Buffer that can hold up to 'length' elements.
-   * All elements of the buffer should be considered invalid until data is explicitly copied in.
-   */
-  override def allocate(length: Int): IntBuffer = {
-    IntBuffer.allocate(length * RECORD_SIZE)
-  }
-}
-
-private object PartitionedSerializedPairBuffer {
-  val KEY_START = 0 // keyStart, a long, gets split across two ints
-  val KEY_VAL_LEN = 2
-  val PARTITION = 3
-  val RECORD_SIZE = PARTITION + 1 // num ints of metadata
-
-  val MAXIMUM_RECORDS = Int.MaxValue / RECORD_SIZE // (2 ^ 29) - 1
-  val MAXIMUM_META_BUFFER_CAPACITY = MAXIMUM_RECORDS * RECORD_SIZE // (2 ^ 31) - 4
-
-  def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = {
-    val lower32 = metaBuffer.get(metaBufferPos + KEY_START)
-    val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1)
-    (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)
-  }
-}
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
similarity index 96%
rename from core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
rename to core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
index 934b7e0305..232ae4d926 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
@@ -15,8 +15,9 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
 
+import org.apache.spark.shuffle.sort.PackedRecordPointer;
 import org.junit.Test;
 import static org.junit.Assert.*;
 
@@ -24,7 +25,7 @@ import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryAllocator;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*;
+import static org.apache.spark.shuffle.sort.PackedRecordPointer.*;
 
 public class PackedRecordPointerSuite {
 
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
similarity index 87%
rename from core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
rename to core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 40fefe2c9d..1ef3c5ff64 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
 
 import java.util.Arrays;
 import java.util.Random;
@@ -30,7 +30,7 @@ import org.apache.spark.unsafe.memory.MemoryAllocator;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
 
-public class UnsafeShuffleInMemorySorterSuite {
+public class ShuffleInMemorySorterSuite {
 
   private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
     final byte[] strBytes = new byte[strLength];
@@ -40,8 +40,8 @@ public class UnsafeShuffleInMemorySorterSuite {
 
   @Test
   public void testSortingEmptyInput() {
-    final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100);
-    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+    final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100);
+    final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
     assert(!iter.hasNext());
   }
 
@@ -62,7 +62,7 @@ public class UnsafeShuffleInMemorySorterSuite {
       new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
     final MemoryBlock dataPage = memoryManager.allocatePage(2048);
     final Object baseObject = dataPage.getBaseObject();
-    final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+    final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
     final HashPartitioner hashPartitioner = new HashPartitioner(4);
 
     // Write the records into the data page and store pointers into the sorter
@@ -79,7 +79,7 @@ public class UnsafeShuffleInMemorySorterSuite {
     }
 
     // Sort the records
-    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+    final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
     int prevPartitionId = -1;
     Arrays.sort(dataToSort);
     for (int i = 0; i < dataToSort.length; i++) {
@@ -103,7 +103,7 @@ public class UnsafeShuffleInMemorySorterSuite {
 
   @Test
   public void testSortingManyNumbers() throws Exception {
-    UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+    ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
     int[] numbersToSort = new int[128000];
     Random random = new Random(16);
     for (int i = 0; i < numbersToSort.length; i++) {
@@ -112,7 +112,7 @@ public class UnsafeShuffleInMemorySorterSuite {
     }
     Arrays.sort(numbersToSort);
     int[] sorterResult = new int[numbersToSort.length];
-    UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+    ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
     int j = 0;
     while (iter.hasNext()) {
       iter.loadNext();
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
similarity index 98%
rename from core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
rename to core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index d218344cd4..29d9823b1f 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
 
 import java.io.*;
 import java.nio.ByteBuffer;
@@ -23,7 +23,6 @@ import java.util.*;
 
 import scala.*;
 import scala.collection.Iterator;
-import scala.reflect.ClassTag;
 import scala.runtime.AbstractFunction1;
 
 import com.google.common.collect.Iterators;
@@ -56,6 +55,7 @@ import org.apache.spark.serializer.*;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.shuffle.sort.SerializedShuffleHandle;
 import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryAllocator;
@@ -204,7 +204,7 @@ public class UnsafeShuffleWriterSuite {
       shuffleBlockResolver,
       taskMemoryManager,
       shuffleMemoryManager,
-      new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
+      new SerializedShuffleHandle<Object, Object>(0, 1, shuffleDep),
       0, // map id
       taskContext,
       conf
@@ -461,7 +461,7 @@ public class UnsafeShuffleWriterSuite {
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
     final ArrayList<Product2<Object, Object>> dataToWrite =
       new ArrayList<Product2<Object, Object>>();
-    final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
+    final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
     new Random(42).nextBytes(bytes);
     dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
     writer.write(dataToWrite.iterator());
@@ -516,7 +516,7 @@ public class UnsafeShuffleWriterSuite {
         shuffleBlockResolver,
         taskMemoryManager,
         shuffleMemoryManager,
-        new UnsafeShuffleHandle<>(0, 1, shuffleDep),
+        new SerializedShuffleHandle<>(0, 1, shuffleDep),
         0, // map id
         taskContext,
         conf);
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
index 63358172ea..b8ab227517 100644
--- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -17,13 +17,78 @@
 
 package org.apache.spark
 
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.util.Utils
+
 class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
 
   // This test suite should run all tests in ShuffleSuite with sort-based shuffle.
 
+  private var tempDir: File = _
+
   override def beforeAll() {
     conf.set("spark.shuffle.manager", "sort")
   }
+
+  override def beforeEach(): Unit = {
+    tempDir = Utils.createTempDir()
+    conf.set("spark.local.dir", tempDir.getAbsolutePath)
+  }
+
+  override def afterEach(): Unit = {
+    try {
+      Utils.deleteRecursively(tempDir)
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") {
+    sc = new SparkContext("local", "test", conf)
+    // Create a shuffled RDD and verify that it actually uses the new serialized map output path
+    val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+    val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+      .setSerializer(new KryoSerializer(conf))
+    val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+    assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep))
+    ensureFilesAreCleanedUp(shuffledRdd)
+  }
+
+  test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") {
+    sc = new SparkContext("local", "test", conf)
+    // Create a shuffled RDD and verify that it actually uses the old deserialized map output path
+    val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+    val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+      .setSerializer(new JavaSerializer(conf))
+    val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+    assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep))
+    ensureFilesAreCleanedUp(shuffledRdd)
+  }
+
+  private def ensureFilesAreCleanedUp(shuffledRdd: ShuffledRDD[_, _, _]): Unit = {
+    def getAllFiles: Set[File] =
+      FileUtils.listFiles(tempDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+    val filesBeforeShuffle = getAllFiles
+    // Force the shuffle to be performed
+    shuffledRdd.count()
+    // Ensure that the shuffle actually created files that will need to be cleaned up
+    val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+    filesCreatedByShuffle.map(_.getName) should be
+    Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")
+    // Check that the cleanup actually removes the files
+    sc.env.blockManager.master.removeShuffle(0, blocking = true)
+    for (file <- filesCreatedByShuffle) {
+      assert (!file.exists(), s"Shuffle file $file was not cleaned up")
+    }
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 5b01ddb298..3816b8c4a0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1062,10 +1062,10 @@ class DAGSchedulerSuite
    */
   test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") {
     val firstRDD = new MyRDD(sc, 3, Nil)
-    val firstShuffleDep = new ShuffleDependency(firstRDD, null)
+    val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
     val firstShuffleId = firstShuffleDep.shuffleId
     val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep))
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
     submit(reduceRdd, Array(0))
 
@@ -1175,7 +1175,7 @@ class DAGSchedulerSuite
    */
   test("register map outputs correctly after ExecutorLost and task Resubmitted") {
     val firstRDD = new MyRDD(sc, 3, Nil)
-    val firstShuffleDep = new ShuffleDependency(firstRDD, null)
+    val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
     val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep))
     submit(reduceRdd, Array(0))
 
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 341f56df2d..b92a302806 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -33,7 +33,8 @@ import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark._
 import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics}
-import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer}
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.serializer.{JavaSerializer, SerializerInstance}
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
 
@@ -42,25 +43,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
   @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
   @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
   @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _
 
   private var taskMetrics: TaskMetrics = _
-  private var shuffleWriteMetrics: ShuffleWriteMetrics = _
   private var tempDir: File = _
   private var outputFile: File = _
   private val conf: SparkConf = new SparkConf(loadDefaults = false)
   private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]()
   private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File]
-  private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0)
-  private val serializer: Serializer = new JavaSerializer(conf)
+  private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _
 
   override def beforeEach(): Unit = {
     tempDir = Utils.createTempDir()
     outputFile = File.createTempFile("shuffle", null, tempDir)
-    shuffleWriteMetrics = new ShuffleWriteMetrics
     taskMetrics = new TaskMetrics
-    taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
     MockitoAnnotations.initMocks(this)
+    shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int](
+      shuffleId = 0,
+      numMaps = 2,
+      dependency = dependency
+    )
+    when(dependency.partitioner).thenReturn(new HashPartitioner(7))
+    when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf)))
     when(taskContext.taskMetrics()).thenReturn(taskMetrics)
+    when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
     when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
     when(blockManager.getDiskWriter(
       any[BlockId],
@@ -107,18 +114,20 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
 
   test("write empty iterator") {
     val writer = new BypassMergeSortShuffleWriter[Int, Int](
-      new SparkConf(loadDefaults = false),
       blockManager,
-      new HashPartitioner(7),
-      shuffleWriteMetrics,
-      serializer
+      blockResolver,
+      shuffleHandle,
+      0, // MapId
+      taskContext,
+      conf
     )
-    writer.insertAll(Iterator.empty)
-    val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
-    assert(partitionLengths.sum === 0)
+    writer.write(Iterator.empty)
+    writer.stop( /* success = */ true)
+    assert(writer.getPartitionLengths.sum === 0)
     assert(outputFile.exists())
     assert(outputFile.length() === 0)
     assert(temporaryFilesCreated.isEmpty)
+    val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get
     assert(shuffleWriteMetrics.shuffleBytesWritten === 0)
     assert(shuffleWriteMetrics.shuffleRecordsWritten === 0)
     assert(taskMetrics.diskBytesSpilled === 0)
@@ -129,17 +138,19 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
     def records: Iterator[(Int, Int)] =
       Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
     val writer = new BypassMergeSortShuffleWriter[Int, Int](
-      new SparkConf(loadDefaults = false),
       blockManager,
-      new HashPartitioner(7),
-      shuffleWriteMetrics,
-      serializer
+      blockResolver,
+      shuffleHandle,
+      0, // MapId
+      taskContext,
+      conf
     )
-    writer.insertAll(records)
+    writer.write(records)
+    writer.stop( /* success = */ true)
     assert(temporaryFilesCreated.nonEmpty)
-    val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
-    assert(partitionLengths.sum === outputFile.length())
+    assert(writer.getPartitionLengths.sum === outputFile.length())
     assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
+    val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get
     assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length())
     assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length)
     assert(taskMetrics.diskBytesSpilled === 0)
@@ -148,14 +159,15 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
 
   test("cleanup of intermediate files after errors") {
     val writer = new BypassMergeSortShuffleWriter[Int, Int](
-      new SparkConf(loadDefaults = false),
       blockManager,
-      new HashPartitioner(7),
-      shuffleWriteMetrics,
-      serializer
+      blockResolver,
+      shuffleHandle,
+      0, // MapId
+      taskContext,
+      conf
     )
     intercept[SparkException] {
-      writer.insertAll((0 until 100000).iterator.map(i => {
+      writer.write((0 until 100000).iterator.map(i => {
         if (i == 99990) {
           throw new SparkException("Intentional failure")
         }
@@ -163,7 +175,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       }))
     }
     assert(temporaryFilesCreated.nonEmpty)
-    writer.stop()
+    writer.stop( /* success = */ false)
     assert(temporaryFilesCreated.count(_.exists()) === 0)
   }
 
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
similarity index 80%
rename from core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
rename to core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
index 6727934d8c..8744a072cb 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.unsafe
+package org.apache.spark.shuffle.sort
 
 import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
@@ -29,9 +29,9 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
  * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are
  * performed in other suites.
  */
-class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
+class SortShuffleManagerSuite extends SparkFunSuite with Matchers {
 
-  import UnsafeShuffleManager.canUseUnsafeShuffle
+  import SortShuffleManager.canUseSerializedShuffle
 
   private class RuntimeExceptionAnswer extends Answer[Object] {
     override def answer(invocation: InvocationOnMock): Object = {
@@ -55,10 +55,10 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
     dep
   }
 
-  test("supported shuffle dependencies") {
+  test("supported shuffle dependencies for serialized shuffle") {
     val kryo = Some(new KryoSerializer(new SparkConf()))
 
-    assert(canUseUnsafeShuffle(shuffleDep(
+    assert(canUseSerializedShuffle(shuffleDep(
       partitioner = new HashPartitioner(2),
       serializer = kryo,
       keyOrdering = None,
@@ -68,7 +68,7 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
 
     val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]])
     when(rangePartitioner.numPartitions).thenReturn(2)
-    assert(canUseUnsafeShuffle(shuffleDep(
+    assert(canUseSerializedShuffle(shuffleDep(
       partitioner = rangePartitioner,
       serializer = kryo,
       keyOrdering = None,
@@ -77,7 +77,7 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
     )))
 
     // Shuffles with key orderings are supported as long as no aggregator is specified
-    assert(canUseUnsafeShuffle(shuffleDep(
+    assert(canUseSerializedShuffle(shuffleDep(
       partitioner = new HashPartitioner(2),
       serializer = kryo,
       keyOrdering = Some(mock(classOf[Ordering[Any]])),
@@ -87,12 +87,12 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
 
   }
 
-  test("unsupported shuffle dependencies") {
+  test("unsupported shuffle dependencies for serialized shuffle") {
     val kryo = Some(new KryoSerializer(new SparkConf()))
     val java = Some(new JavaSerializer(new SparkConf()))
 
     // We only support serializers that support object relocation
-    assert(!canUseUnsafeShuffle(shuffleDep(
+    assert(!canUseSerializedShuffle(shuffleDep(
       partitioner = new HashPartitioner(2),
       serializer = java,
       keyOrdering = None,
@@ -100,9 +100,11 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
       mapSideCombine = false
     )))
 
-    // We do not support shuffles with more than 16 million output partitions
-    assert(!canUseUnsafeShuffle(shuffleDep(
-      partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1),
+    // The serialized shuffle path do not support shuffles with more than 16 million output
+    // partitions, due to a limitation in its sorter implementation.
+    assert(!canUseSerializedShuffle(shuffleDep(
+      partitioner = new HashPartitioner(
+        SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1),
       serializer = kryo,
       keyOrdering = None,
       aggregator = None,
@@ -110,14 +112,14 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
     )))
 
     // We do not support shuffles that perform aggregation
-    assert(!canUseUnsafeShuffle(shuffleDep(
+    assert(!canUseSerializedShuffle(shuffleDep(
       partitioner = new HashPartitioner(2),
       serializer = kryo,
       keyOrdering = None,
       aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
       mapSideCombine = false
     )))
-    assert(!canUseUnsafeShuffle(shuffleDep(
+    assert(!canUseSerializedShuffle(shuffleDep(
       partitioner = new HashPartitioner(2),
       serializer = kryo,
       keyOrdering = Some(mock(classOf[Ordering[Any]])),
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
deleted file mode 100644
index 34b4984f12..0000000000
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.sort
-
-import org.mockito.Mockito._
-
-import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite}
-
-class SortShuffleWriterSuite extends SparkFunSuite {
-
-  import SortShuffleWriter._
-
-  test("conditions for bypassing merge-sort") {
-    val conf = new SparkConf(loadDefaults = false)
-    val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS)
-    val ord = implicitly[Ordering[Int]]
-
-    // Numbers of partitions that are above and below the default bypassMergeThreshold
-    val FEW_PARTITIONS = 50
-    val MANY_PARTITIONS = 10000
-
-    // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high
-    assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None))
-    assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None))
-
-    // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions
-    assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord)))
-    assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None))
-  }
-}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
deleted file mode 100644
index 259020a2dd..0000000000
--- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe
-
-import java.io.File
-
-import scala.collection.JavaConverters._
-
-import org.apache.commons.io.FileUtils
-import org.apache.commons.io.filefilter.TrueFileFilter
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
-import org.apache.spark.rdd.ShuffledRDD
-import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
-import org.apache.spark.util.Utils
-
-class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
-
-  // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle.
-
-  override def beforeAll() {
-    conf.set("spark.shuffle.manager", "tungsten-sort")
-  }
-
-  test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") {
-    val tmpDir = Utils.createTempDir()
-    try {
-      val myConf = conf.clone()
-        .set("spark.local.dir", tmpDir.getAbsolutePath)
-      sc = new SparkContext("local", "test", myConf)
-      // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path
-      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
-      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
-        .setSerializer(new KryoSerializer(myConf))
-      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
-      assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
-      def getAllFiles: Set[File] =
-        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
-      val filesBeforeShuffle = getAllFiles
-      // Force the shuffle to be performed
-      shuffledRdd.count()
-      // Ensure that the shuffle actually created files that will need to be cleaned up
-      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
-      filesCreatedByShuffle.map(_.getName) should be
-        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")
-      // Check that the cleanup actually removes the files
-      sc.env.blockManager.master.removeShuffle(0, blocking = true)
-      for (file <- filesCreatedByShuffle) {
-        assert (!file.exists(), s"Shuffle file $file was not cleaned up")
-      }
-    } finally {
-      Utils.deleteRecursively(tmpDir)
-    }
-  }
-
-  test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") {
-    val tmpDir = Utils.createTempDir()
-    try {
-      val myConf = conf.clone()
-        .set("spark.local.dir", tmpDir.getAbsolutePath)
-      sc = new SparkContext("local", "test", myConf)
-      // Create a shuffled RDD and verify that it will actually use the old SortShuffle path
-      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
-      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
-        .setSerializer(new JavaSerializer(myConf))
-      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
-      assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
-      def getAllFiles: Set[File] =
-        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
-      val filesBeforeShuffle = getAllFiles
-      // Force the shuffle to be performed
-      shuffledRdd.count()
-      // Ensure that the shuffle actually created files that will need to be cleaned up
-      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
-      filesCreatedByShuffle.map(_.getName) should be
-        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")
-      // Check that the cleanup actually removes the files
-      sc.env.blockManager.master.removeShuffle(0, blocking = true)
-      for (file <- filesCreatedByShuffle) {
-        assert (!file.exists(), s"Shuffle file $file was not cleaned up")
-      }
-    } finally {
-      Utils.deleteRecursively(tmpDir)
-    }
-  }
-}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
deleted file mode 100644
index 05306f4088..0000000000
--- a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
+++ /dev/null
@@ -1,144 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.nio.ByteBuffer
-
-import org.scalatest.Matchers._
-
-import org.apache.spark.SparkFunSuite
-
-class ChainedBufferSuite extends SparkFunSuite {
-  test("write and read at start") {
-    // write from start of source array
-    val buffer = new ChainedBuffer(8)
-    buffer.capacity should be (0)
-    verifyWriteAndRead(buffer, 0, 0, 0, 4)
-    buffer.capacity should be (8)
-
-    // write from middle of source array
-    verifyWriteAndRead(buffer, 0, 5, 0, 4)
-    buffer.capacity should be (8)
-
-    // read to middle of target array
-    verifyWriteAndRead(buffer, 0, 0, 5, 4)
-    buffer.capacity should be (8)
-
-    // write up to border
-    verifyWriteAndRead(buffer, 0, 0, 0, 8)
-    buffer.capacity should be (8)
-
-    // expand into second buffer
-    verifyWriteAndRead(buffer, 0, 0, 0, 12)
-    buffer.capacity should be (16)
-
-    // expand into multiple buffers
-    verifyWriteAndRead(buffer, 0, 0, 0, 28)
-    buffer.capacity should be (32)
-  }
-
-  test("write and read at middle") {
-    val buffer = new ChainedBuffer(8)
-
-    // fill to a middle point
-    verifyWriteAndRead(buffer, 0, 0, 0, 3)
-
-    // write from start of source array
-    verifyWriteAndRead(buffer, 3, 0, 0, 4)
-    buffer.capacity should be (8)
-
-    // write from middle of source array
-    verifyWriteAndRead(buffer, 3, 5, 0, 4)
-    buffer.capacity should be (8)
-
-    // read to middle of target array
-    verifyWriteAndRead(buffer, 3, 0, 5, 4)
-    buffer.capacity should be (8)
-
-    // write up to border
-    verifyWriteAndRead(buffer, 3, 0, 0, 5)
-    buffer.capacity should be (8)
-
-    // expand into second buffer
-    verifyWriteAndRead(buffer, 3, 0, 0, 12)
-    buffer.capacity should be (16)
-
-    // expand into multiple buffers
-    verifyWriteAndRead(buffer, 3, 0, 0, 28)
-    buffer.capacity should be (32)
-  }
-
-  test("write and read at later buffer") {
-    val buffer = new ChainedBuffer(8)
-
-    // fill to a middle point
-    verifyWriteAndRead(buffer, 0, 0, 0, 11)
-
-    // write from start of source array
-    verifyWriteAndRead(buffer, 11, 0, 0, 4)
-    buffer.capacity should be (16)
-
-    // write from middle of source array
-    verifyWriteAndRead(buffer, 11, 5, 0, 4)
-    buffer.capacity should be (16)
-
-    // read to middle of target array
-    verifyWriteAndRead(buffer, 11, 0, 5, 4)
-    buffer.capacity should be (16)
-
-    // write up to border
-    verifyWriteAndRead(buffer, 11, 0, 0, 5)
-    buffer.capacity should be (16)
-
-    // expand into second buffer
-    verifyWriteAndRead(buffer, 11, 0, 0, 12)
-    buffer.capacity should be (24)
-
-    // expand into multiple buffers
-    verifyWriteAndRead(buffer, 11, 0, 0, 28)
-    buffer.capacity should be (40)
-  }
-
-
-  // Used to make sure we're writing different bytes each time
-  var rangeStart = 0
-
-  /**
-   * @param buffer The buffer to write to and read from.
-   * @param offsetInBuffer The offset to write to in the buffer.
-   * @param offsetInSource The offset in the array that the bytes are written from.
-   * @param offsetInTarget The offset in the array to read the bytes into.
-   * @param length The number of bytes to read and write
-   */
-  def verifyWriteAndRead(
-      buffer: ChainedBuffer,
-      offsetInBuffer: Int,
-      offsetInSource: Int,
-      offsetInTarget: Int,
-      length: Int): Unit = {
-    val source = new Array[Byte](offsetInSource + length)
-    (rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource)
-    buffer.write(offsetInBuffer, source, offsetInSource, length)
-    val target = new Array[Byte](offsetInTarget + length)
-    buffer.read(offsetInBuffer, target, offsetInTarget, length)
-    ByteBuffer.wrap(source, offsetInSource, length) should be
-      (ByteBuffer.wrap(target, offsetInTarget, length))
-
-    rangeStart += 100
-  }
-}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
deleted file mode 100644
index 3b67f62064..0000000000
--- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
+++ /dev/null
@@ -1,148 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
-
-import com.google.common.io.ByteStreams
-
-import org.mockito.Matchers.any
-import org.mockito.Mockito._
-import org.mockito.Mockito.RETURNS_SMART_NULLS
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
-import org.scalatest.Matchers._
-
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.storage.DiskBlockObjectWriter
-
-class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
-  test("OrderedInputStream single record") {
-    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
-
-    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
-    val struct = SomeStruct("something", 5)
-    buffer.insert(4, 10, struct)
-
-    val bytes = ByteStreams.toByteArray(buffer.orderedInputStream)
-
-    val baos = new ByteArrayOutputStream()
-    val stream = serializerInstance.serializeStream(baos)
-    stream.writeObject(10)
-    stream.writeObject(struct)
-    stream.close()
-
-    baos.toByteArray should be (bytes)
-  }
-
-  test("insert single record") {
-    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
-    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
-    val struct = SomeStruct("something", 5)
-    buffer.insert(4, 10, struct)
-    val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
-    elements.size should be (1)
-    elements.head should be (((4, 10), struct))
-  }
-
-  test("insert multiple records") {
-    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
-    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
-    val struct1 = SomeStruct("something1", 8)
-    buffer.insert(6, 1, struct1)
-    val struct2 = SomeStruct("something2", 9)
-    buffer.insert(4, 2, struct2)
-    val struct3 = SomeStruct("something3", 10)
-    buffer.insert(5, 3, struct3)
-
-    val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
-    elements.size should be (3)
-    elements(0) should be (((4, 2), struct2))
-    elements(1) should be (((5, 3), struct3))
-    elements(2) should be (((6, 1), struct1))
-  }
-
-  test("write single record") {
-    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
-    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
-    val struct = SomeStruct("something", 5)
-    buffer.insert(4, 10, struct)
-    val it = buffer.destructiveSortedWritablePartitionedIterator(None)
-    val (writer, baos) = createMockWriter()
-    assert(it.hasNext)
-    it.nextPartition should be (4)
-    it.writeNext(writer)
-    assert(!it.hasNext)
-
-    val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
-    stream.readObject[AnyRef]() should be (10)
-    stream.readObject[AnyRef]() should be (struct)
-  }
-
-  test("write multiple records") {
-    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
-    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
-    val struct1 = SomeStruct("something1", 8)
-    buffer.insert(6, 1, struct1)
-    val struct2 = SomeStruct("something2", 9)
-    buffer.insert(4, 2, struct2)
-    val struct3 = SomeStruct("something3", 10)
-    buffer.insert(5, 3, struct3)
-
-    val it = buffer.destructiveSortedWritablePartitionedIterator(None)
-    val (writer, baos) = createMockWriter()
-    assert(it.hasNext)
-    it.nextPartition should be (4)
-    it.writeNext(writer)
-    assert(it.hasNext)
-    it.nextPartition should be (5)
-    it.writeNext(writer)
-    assert(it.hasNext)
-    it.nextPartition should be (6)
-    it.writeNext(writer)
-    assert(!it.hasNext)
-
-    val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
-    val iter = stream.asIterator
-    iter.next() should be (2)
-    iter.next() should be (struct2)
-    iter.next() should be (3)
-    iter.next() should be (struct3)
-    iter.next() should be (1)
-    iter.next() should be (struct1)
-    assert(!iter.hasNext)
-  }
-
-  def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = {
-    val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS)
-    val baos = new ByteArrayOutputStream()
-    when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] {
-      override def answer(invocationOnMock: InvocationOnMock): Unit = {
-        val args = invocationOnMock.getArguments
-        val bytes = args(0).asInstanceOf[Array[Byte]]
-        val offset = args(1).asInstanceOf[Int]
-        val length = args(2).asInstanceOf[Int]
-        baos.write(bytes, offset, length)
-      }
-    })
-    (writer, baos)
-  }
-}
-
-case class SomeStruct(str: String, num: Int)
diff --git a/docs/configuration.md b/docs/configuration.md
index 46d92ceb76..be9c36bdfe 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -437,12 +437,9 @@ Apart from these, the following properties are also available, and may be useful
   <td><code>spark.shuffle.manager</code></td>
   <td>sort</td>
   <td>
-    Implementation to use for shuffling data. There are three implementations available:
-    <code>sort</code>, <code>hash</code> and the new (1.5+) <code>tungsten-sort</code>.
+    Implementation to use for shuffling data. There are two implementations available:
+    <code>sort</code> and <code>hash</code>.
     Sort-based shuffle is more memory-efficient and is the default option starting in 1.2.
-    Tungsten-sort is similar to the sort based shuffle, with a direct binary cache-friendly
-    implementation with a fall back to regular sort based shuffle if its requirements are not
-    met.
   </td>
 </tr>
 <tr>
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 0872d3f3e7..b5e661d3ec 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -37,6 +37,7 @@ object MimaExcludes {
       Seq(
         MimaBuild.excludeSparkPackage("deploy"),
         MimaBuild.excludeSparkPackage("network"),
+        MimaBuild.excludeSparkPackage("unsafe"),
         // These are needed if checking against the sbt build, since they are part of
         // the maven-generated artifacts in 1.3.
         excludePackage("org.spark-project.jetty"),
@@ -44,7 +45,11 @@ object MimaExcludes {
         // SQL execution is considered private.
         excludePackage("org.apache.spark.sql.execution"),
         // SQL columnar is considered private.
-        excludePackage("org.apache.spark.sql.columnar")
+        excludePackage("org.apache.spark.sql.columnar"),
+        // The shuffle package is considered private.
+        excludePackage("org.apache.spark.shuffle"),
+        // The collections utlities are considered pricate.
+        excludePackage("org.apache.spark.util.collection")
       ) ++
       MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++
       MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++
@@ -750,4 +755,4 @@ object MimaExcludes {
       MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD")
     case _ => Seq()
   }
-}
\ No newline at end of file
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 1d3379a5e2..7f60c8f5ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -24,7 +24,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors.attachTree
@@ -87,10 +86,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
     // fewer partitions (like RangePartitioner, for example).
     val conf = child.sqlContext.sparkContext.conf
     val shuffleManager = SparkEnv.get.shuffleManager
-    val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] ||
-      shuffleManager.isInstanceOf[UnsafeShuffleManager]
+    val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager]
     val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-    val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true)
     if (sortBasedShuffleOn) {
       val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
       if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) {
@@ -99,22 +96,18 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
         // doesn't buffer deserialized records.
         // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
         false
-      } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) {
-        // SPARK-4550 extended sort-based shuffle to serialize individual records prior to sorting
-        // them. This optimization is guarded by a feature-flag and is only applied in cases where
-        // shuffle dependency does not specify an aggregator or ordering and the record serializer
-        // has certain properties. If this optimization is enabled, we can safely avoid the copy.
+      } else if (serializer.supportsRelocationOfSerializedObjects) {
+        // SPARK-4550 and  SPARK-7081 extended sort-based shuffle to serialize individual records
+        // prior to sorting them. This optimization is only applied in cases where shuffle
+        // dependency does not specify an aggregator or ordering and the record serializer has
+        // certain properties. If this optimization is enabled, we can safely avoid the copy.
         //
         // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only
         // need to check whether the optimization is enabled and supported by our serializer.
-        //
-        // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081).
         false
       } else {
-        // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code
-        // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls
-        // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In
-        // both cases, we must copy.
+        // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must
+        // copy.
         true
       }
     } else if (shuffleManager.isInstanceOf[HashShuffleManager]) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 75d1fced59..1680d7e0a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -101,7 +101,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
     val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten
     Utils.tryWithSafeFinally {
       val conf = new SparkConf()
-        .set("spark.shuffle.spill.initialMemoryThreshold", "1024")
+        .set("spark.shuffle.spill.initialMemoryThreshold", "1")
         .set("spark.shuffle.sort.bypassMergeThreshold", "0")
         .set("spark.testing.memory", "80000")
 
@@ -109,7 +109,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
       outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
       // prepare data
       val converter = unsafeRowConverter(Array(IntegerType))
-      val data = (1 to 1000).iterator.map { i =>
+      val data = (1 to 10000).iterator.map { i =>
         (i, converter(Row(i)))
       }
       val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
@@ -141,9 +141,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
-  test("SPARK-10403: unsafe row serializer with UnsafeShuffleManager") {
-    val conf = new SparkConf()
-      .set("spark.shuffle.manager", "tungsten-sort")
+  test("SPARK-10403: unsafe row serializer with SortShuffleManager") {
+    val conf = new SparkConf().set("spark.shuffle.manager", "sort")
     sc = new SparkContext("local", "test", conf)
     val row = Row("Hello", 123)
     val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
-- 
GitLab