diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
new file mode 100644
index 0000000000000000000000000000000000000000..d3d6280284bebcbe170a264662b416d44b685581
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.Tuple2;
+import scala.collection.Iterator;
+
+import com.google.common.io.Closeables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.Partitioner;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.storage.*;
+import org.apache.spark.util.Utils;
+
+/**
+ * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path
+ * writes incoming records to separate files, one file per reduce partition, then concatenates these
+ * per-partition files to form a single output file, regions of which are served to reducers.
+ * Records are not buffered in memory. This is essentially identical to
+ * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format
+ * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}.
+ * <p>
+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it
+ * simultaneously opens separate serializers and file streams for all partitions. As a result,
+ * {@link SortShuffleManager} only selects this write path when
+ * <ul>
+ *    <li>no Ordering is specified,</li>
+ *    <li>no Aggregator is specific, and</li>
+ *    <li>the number of partitions is less than
+ *      <code>spark.shuffle.sort.bypassMergeThreshold</code>.</li>
+ * </ul>
+ *
+ * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was
+ * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details.
+ * <p>
+ * There have been proposals to completely remove this code path; see SPARK-6026 for details.
+ */
+final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
+
+  private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
+
+  private final int fileBufferSize;
+  private final boolean transferToEnabled;
+  private final int numPartitions;
+  private final BlockManager blockManager;
+  private final Partitioner partitioner;
+  private final ShuffleWriteMetrics writeMetrics;
+  private final Serializer serializer;
+
+  /** Array of file writers, one for each partition */
+  private BlockObjectWriter[] partitionWriters;
+
+  public BypassMergeSortShuffleWriter(
+      SparkConf conf,
+      BlockManager blockManager,
+      Partitioner partitioner,
+      ShuffleWriteMetrics writeMetrics,
+      Serializer serializer) {
+    // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+    this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+    this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
+    this.numPartitions = partitioner.numPartitions();
+    this.blockManager = blockManager;
+    this.partitioner = partitioner;
+    this.writeMetrics = writeMetrics;
+    this.serializer = serializer;
+  }
+
+  @Override
+  public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
+    assert (partitionWriters == null);
+    if (!records.hasNext()) {
+      return;
+    }
+    final SerializerInstance serInstance = serializer.newInstance();
+    final long openStartTime = System.nanoTime();
+    partitionWriters = new BlockObjectWriter[numPartitions];
+    for (int i = 0; i < numPartitions; i++) {
+      final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+        blockManager.diskBlockManager().createTempShuffleBlock();
+      final File file = tempShuffleBlockIdPlusFile._2();
+      final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+      partitionWriters[i] =
+        blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open();
+    }
+    // Creating the file to write to and creating a disk writer both involve interacting with
+    // the disk, and can take a long time in aggregate when we open many files, so should be
+    // included in the shuffle write time.
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime);
+
+    while (records.hasNext()) {
+      final Product2<K, V> record = records.next();
+      final K key = record._1();
+      partitionWriters[partitioner.getPartition(key)].write(key, record._2());
+    }
+
+    for (BlockObjectWriter writer : partitionWriters) {
+      writer.commitAndClose();
+    }
+  }
+
+  @Override
+  public long[] writePartitionedFile(
+      BlockId blockId,
+      TaskContext context,
+      File outputFile) throws IOException {
+    // Track location of the partition starts in the output file
+    final long[] lengths = new long[numPartitions];
+    if (partitionWriters == null) {
+      // We were passed an empty iterator
+      return lengths;
+    }
+
+    final FileOutputStream out = new FileOutputStream(outputFile, true);
+    final long writeStartTime = System.nanoTime();
+    boolean threwException = true;
+    try {
+      for (int i = 0; i < numPartitions; i++) {
+        final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file());
+        boolean copyThrewException = true;
+        try {
+          lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
+          copyThrewException = false;
+        } finally {
+          Closeables.close(in, copyThrewException);
+        }
+        if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) {
+          logger.error("Unable to delete file for partition {}", i);
+        }
+      }
+      threwException = false;
+    } finally {
+      Closeables.close(out, threwException);
+      writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
+    }
+    partitionWriters = null;
+    return lengths;
+  }
+
+  @Override
+  public void stop() throws IOException {
+    if (partitionWriters != null) {
+      try {
+        final DiskBlockManager diskBlockManager = blockManager.diskBlockManager();
+        for (BlockObjectWriter writer : partitionWriters) {
+          // This method explicitly does _not_ throw exceptions:
+          writer.revertPartialWritesAndClose();
+          if (!diskBlockManager.getFile(writer.blockId()).delete()) {
+            logger.error("Error while deleting file for block {}", writer.blockId());
+          }
+        }
+      } finally {
+        partitionWriters = null;
+      }
+    }
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
new file mode 100644
index 0000000000000000000000000000000000000000..656ea0401a144fd8e51968a13d3cb27d5b352fd3
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.collection.Iterator;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.TaskContext;
+import org.apache.spark.storage.BlockId;
+
+/**
+ * Interface for objects that {@link SortShuffleWriter} uses to write its output files.
+ */
+@Private
+public interface SortShuffleFileWriter<K, V> {
+
+  void insertAll(Iterator<Product2<K, V>> records) throws IOException;
+
+  /**
+   * Write all the data added into this shuffle sorter into a file in the disk store. This is
+   * called by the SortShuffleWriter and can go through an efficient path of just concatenating
+   * binary files if we decided to avoid merge-sorting.
+   *
+   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
+   * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   */
+  long[] writePartitionedFile(
+      BlockId blockId,
+      TaskContext context,
+      File outputFile) throws IOException;
+
+  void stop() throws IOException;
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index c9dd6bfc4c2199a6cebd13ec5d162d94cf5f9b14..5865e7640c1cfc2475f13e6b989fd46b89719159 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.shuffle.sort
 
-import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
+import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
@@ -35,7 +36,7 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   private val blockManager = SparkEnv.get.blockManager
 
-  private var sorter: ExternalSorter[K, V, _] = null
+  private var sorter: SortShuffleFileWriter[K, V] = null
 
   // Are we in the process of stopping? Because map tasks can call stop() with success = true
   // and then call stop() with success = false if they get an exception, we want to make sure
@@ -49,18 +50,27 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[Product2[K, V]]): Unit = {
-    if (dep.mapSideCombine) {
+    sorter = if (dep.mapSideCombine) {
       require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
-      sorter = new ExternalSorter[K, V, C](
+      new ExternalSorter[K, V, C](
         dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
-      sorter.insertAll(records)
+    } else if (SortShuffleWriter.shouldBypassMergeSort(
+        SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
+      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
+      // need local aggregation and sorting, write numPartitions files directly and just concatenate
+      // them at the end. This avoids doing serialization and deserialization twice to merge
+      // together the spilled files, which would happen with the normal code path. The downside is
+      // having multiple files open at a time and thus more memory allocated to buffers.
+      new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner,
+        writeMetrics, Serializer.getSerializer(dep.serializer))
     } else {
       // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
       // care whether the keys get sorted in each partition; that will be done on the reduce side
       // if the operation being run is sortByKey.
-      sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer)
-      sorter.insertAll(records)
+      new ExternalSorter[K, V, V](
+        aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
     }
+    sorter.insertAll(records)
 
     // Don't bother including the time to open the merged output file in the shuffle write time,
     // because it just opens a single file, so is typically too fast to measure accurately
@@ -100,3 +110,13 @@ private[spark] class SortShuffleWriter[K, V, C](
   }
 }
 
+private[spark] object SortShuffleWriter {
+  def shouldBypassMergeSort(
+      conf: SparkConf,
+      numPartitions: Int,
+      aggregator: Option[Aggregator[_, _, _]],
+      keyOrdering: Option[Ordering[_]]): Boolean = {
+    val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index a33f22ef52687bddea0e430ed489a9f49af79e4d..7eeabd1e0489ce49aef486cdf3bcb860af14c9c1 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -95,6 +95,7 @@ private[spark] class DiskBlockObjectWriter(
   private var objOut: SerializationStream = null
   private var initialized = false
   private var hasBeenClosed = false
+  private var commitAndCloseHasBeenCalled = false
 
   /**
    * Cursors used to represent positions in the file.
@@ -167,20 +168,22 @@ private[spark] class DiskBlockObjectWriter(
       objOut.flush()
       bs.flush()
       close()
+      finalPosition = file.length()
+      // In certain compression codecs, more bytes are written after close() is called
+      writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
+    } else {
+      finalPosition = file.length()
     }
-    finalPosition = file.length()
-    // In certain compression codecs, more bytes are written after close() is called
-    writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
+    commitAndCloseHasBeenCalled = true
   }
 
   // Discard current writes. We do this by flushing the outstanding writes and then
   // truncating the file to its initial position.
   override def revertPartialWritesAndClose() {
     try {
-      writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
-      writeMetrics.decShuffleRecordsWritten(numRecordsWritten)
-
       if (initialized) {
+        writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
+        writeMetrics.decShuffleRecordsWritten(numRecordsWritten)
         objOut.flush()
         bs.flush()
         close()
@@ -228,6 +231,10 @@ private[spark] class DiskBlockObjectWriter(
   }
 
   override def fileSegment(): FileSegment = {
+    if (!commitAndCloseHasBeenCalled) {
+      throw new IllegalStateException(
+        "fileSegment() is only valid after commitAndClose() has been called")
+    }
     new FileSegment(file, initialPosition, finalPosition - initialPosition)
   }
 
diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
index 95e2d688d9b178a49a2ebfc7e11e227d1ba98b1c..021a9facfb0b2df399e086115511e4c927870ac9 100644
--- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
+++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
@@ -24,6 +24,8 @@ import java.io.File
  * based off an offset and a length.
  */
 private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) {
+  require(offset >= 0, s"File segment offset cannot be negative (got $offset)")
+  require(length >= 0, s"File segment length cannot be negative (got $length)")
   override def toString: String = {
     "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
   }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 3b9d14f9372b618daf2b7d15de148e33fe9e1049..ef2dbb7ff0ae0720ce8a14fc1a46077c2341959c 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -23,12 +23,14 @@ import java.util.Comparator
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable
 
+import com.google.common.annotations.VisibleForTesting
 import com.google.common.io.ByteStreams
 
 import org.apache.spark._
 import org.apache.spark.serializer._
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.storage.{BlockObjectWriter, BlockId}
+import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
+import org.apache.spark.storage.{BlockId, BlockObjectWriter}
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -84,35 +86,40 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId}
  *   each other for equality to merge values.
  *
  * - Users are expected to call stop() at the end to delete all the intermediate files.
- *
- * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is
- * less than spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to
- * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can
- * then concatenate these files to produce a single sorted file, without having to serialize and
- * de-serialize each item twice (as is needed during the merge). This speeds up the map side of
- * groupBy, sort, etc operations since they do no partial aggregation.
  */
 private[spark] class ExternalSorter[K, V, C](
     aggregator: Option[Aggregator[K, V, C]] = None,
     partitioner: Option[Partitioner] = None,
     ordering: Option[Ordering[K]] = None,
     serializer: Option[Serializer] = None)
-  extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] {
+  extends Logging
+  with Spillable[WritablePartitionedPairCollection[K, C]]
+  with SortShuffleFileWriter[K, V] {
+
+  private val conf = SparkEnv.get.conf
 
   private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
   private val shouldPartition = numPartitions > 1
+  private def getPartition(key: K): Int = {
+    if (shouldPartition) partitioner.get.getPartition(key) else 0
+  }
+
+  // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class.
+  // As a sanity check, make sure that we're not handling a shuffle which should use that path.
+  if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) {
+    throw new IllegalArgumentException("ExternalSorter should not be used to handle "
+      + " a sort that the BypassMergeSortShuffleWriter should handle")
+  }
 
   private val blockManager = SparkEnv.get.blockManager
   private val diskBlockManager = blockManager.diskBlockManager
   private val ser = Serializer.getSerializer(serializer)
   private val serInstance = ser.newInstance()
 
-  private val conf = SparkEnv.get.conf
   private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
   
   // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
   private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
-  private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
 
   // Size of object batches when reading/writing from serializers.
   //
@@ -123,43 +130,28 @@ private[spark] class ExternalSorter[K, V, C](
   // grow internal data structures by growing + copying every time the number of objects doubles.
   private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
 
-  private def getPartition(key: K): Int = {
-    if (shouldPartition) partitioner.get.getPartition(key) else 0
-  }
-
-  private val metaInitialRecords = 256
-  private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
   private val useSerializedPairBuffer =
-    !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
-    ser.supportsRelocationOfSerializedObjects
-
+    ordering.isEmpty &&
+      conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
+      ser.supportsRelocationOfSerializedObjects
+  private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
+  private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = {
+    if (useSerializedPairBuffer) {
+      new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance)
+    } else {
+      new PartitionedPairBuffer[K, C]
+    }
+  }
   // Data structures to store in-memory objects before we spill. Depending on whether we have an
   // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
   // store them in an array buffer.
   private var map = new PartitionedAppendOnlyMap[K, C]
-  private var buffer = if (useSerializedPairBuffer) {
-    new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
-  } else {
-    new PartitionedPairBuffer[K, C]
-  }
+  private var buffer = newBuffer()
 
   // Total spilling statistics
   private var _diskBytesSpilled = 0L
+  def diskBytesSpilled: Long = _diskBytesSpilled
 
-  // Write metrics for current spill
-  private var curWriteMetrics: ShuffleWriteMetrics = _
-
-  // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need
-  // local aggregation and sorting, write numPartitions files directly and just concatenate them
-  // at the end. This avoids doing serialization and deserialization twice to merge together the
-  // spilled files, which would happen with the normal code path. The downside is having multiple
-  // files open at a time and thus more memory allocated to buffers.
-  private val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-  private val bypassMergeSort =
-    (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty)
-
-  // Array of file writers for each partition, used if bypassMergeSort is true and we've spilled
-  private var partitionWriters: Array[BlockObjectWriter] = null
 
   // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
   // Can be a partial ordering by hash code if a total ordering is not provided through by the
@@ -174,6 +166,14 @@ private[spark] class ExternalSorter[K, V, C](
     }
   })
 
+  private def comparator: Option[Comparator[K]] = {
+    if (ordering.isDefined || aggregator.isDefined) {
+      Some(keyComparator)
+    } else {
+      None
+    }
+  }
+
   // Information about a spilled file. Includes sizes in bytes of "batches" written by the
   // serializer as we periodically reset its stream, as well as number of elements in each
   // partition, used to efficiently keep track of partitions when merging.
@@ -182,9 +182,10 @@ private[spark] class ExternalSorter[K, V, C](
     blockId: BlockId,
     serializerBatchSizes: Array[Long],
     elementsPerPartition: Array[Long])
+
   private val spills = new ArrayBuffer[SpilledFile]
 
-  def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
+  override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
     // TODO: stop combining if we find that the reduction factor isn't high
     val shouldCombine = aggregator.isDefined
 
@@ -202,15 +203,6 @@ private[spark] class ExternalSorter[K, V, C](
         map.changeValue((getPartition(kv._1), kv._1), update)
         maybeSpillCollection(usingMap = true)
       }
-    } else if (bypassMergeSort) {
-      // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
-      if (records.hasNext) {
-        spillToPartitionFiles(
-          WritablePartitionedIterator.fromIterator(records.map { kv =>
-            ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
-          })
-        )
-      }
     } else {
       // Stick values into our buffer
       while (records.hasNext) {
@@ -238,46 +230,33 @@ private[spark] class ExternalSorter[K, V, C](
       }
     } else {
       if (maybeSpill(buffer, buffer.estimateSize())) {
-        buffer = if (useSerializedPairBuffer) {
-          new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
-        } else {
-          new PartitionedPairBuffer[K, C]
-        }
+        buffer = newBuffer()
       }
     }
   }
 
   /**
-   * Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
-   */
-  override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
-    if (bypassMergeSort) {
-      spillToPartitionFiles(collection)
-    } else {
-      spillToMergeableFile(collection)
-    }
-  }
-
-  /**
-   * Spill our in-memory collection to a sorted file that we can merge later (normal code path).
-   * We add this file into spilledFiles to find it later.
-   *
-   * This should not be invoked if bypassMergeSort is true. In that case, spillToPartitionedFiles()
-   * is used to write files for each partition.
+   * Spill our in-memory collection to a sorted file that we can merge later.
+   * We add this file into `spilledFiles` to find it later.
    *
    * @param collection whichever collection we're using (map or buffer)
    */
-  private def spillToMergeableFile(collection: WritablePartitionedPairCollection[K, C]): Unit = {
-    assert(!bypassMergeSort)
-
+  override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
     // Because these files may be read during shuffle, their compression must be controlled by
     // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
     // createTempShuffleBlock here; see SPARK-3426 for more context.
     val (blockId, file) = diskBlockManager.createTempShuffleBlock()
-    curWriteMetrics = new ShuffleWriteMetrics()
-    var writer = blockManager.getDiskWriter(
-      blockId, file, serInstance, fileBufferSize, curWriteMetrics)
-    var objectsWritten = 0   // Objects written since the last flush
+
+    // These variables are reset after each flush
+    var objectsWritten: Long = 0
+    var spillMetrics: ShuffleWriteMetrics = null
+    var writer: BlockObjectWriter = null
+    def openWriter(): Unit = {
+      assert (writer == null && spillMetrics == null)
+      spillMetrics = new ShuffleWriteMetrics
+      writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
+    }
+    openWriter()
 
     // List of batch sizes (bytes) in the order they are written to disk
     val batchSizes = new ArrayBuffer[Long]
@@ -291,8 +270,9 @@ private[spark] class ExternalSorter[K, V, C](
       val w = writer
       writer = null
       w.commitAndClose()
-      _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
-      batchSizes.append(curWriteMetrics.shuffleBytesWritten)
+      _diskBytesSpilled += spillMetrics.shuffleBytesWritten
+      batchSizes.append(spillMetrics.shuffleBytesWritten)
+      spillMetrics = null
       objectsWritten = 0
     }
 
@@ -307,9 +287,7 @@ private[spark] class ExternalSorter[K, V, C](
 
         if (objectsWritten == serializerBatchSize) {
           flush()
-          curWriteMetrics = new ShuffleWriteMetrics()
-          writer = blockManager.getDiskWriter(
-            blockId, file, serInstance, fileBufferSize, curWriteMetrics)
+          openWriter()
         }
       }
       if (objectsWritten > 0) {
@@ -336,46 +314,6 @@ private[spark] class ExternalSorter[K, V, C](
     spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
   }
 
-  /**
-   * Spill our in-memory collection to separate files, one for each partition. This is used when
-   * there's no aggregator and ordering and the number of partitions is small, because it allows
-   * writePartitionedFile to just concatenate files without deserializing data.
-   *
-   * @param collection whichever collection we're using (map or buffer)
-   */
-  private def spillToPartitionFiles(collection: WritablePartitionedPairCollection[K, C]): Unit = {
-    spillToPartitionFiles(collection.writablePartitionedIterator())
-  }
-
-  private def spillToPartitionFiles(iterator: WritablePartitionedIterator): Unit = {
-    assert(bypassMergeSort)
-
-    // Create our file writers if we haven't done so yet
-    if (partitionWriters == null) {
-      curWriteMetrics = new ShuffleWriteMetrics()
-      val openStartTime = System.nanoTime
-      partitionWriters = Array.fill(numPartitions) {
-        // Because these files may be read during shuffle, their compression must be controlled by
-        // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
-        // createTempShuffleBlock here; see SPARK-3426 for more context.
-        val (blockId, file) = diskBlockManager.createTempShuffleBlock()
-        val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize,
-          curWriteMetrics)
-        writer.open()
-      }
-      // Creating the file to write to and creating a disk writer both involve interacting with
-      // the disk, and can take a long time in aggregate when we open many files, so should be
-      // included in the shuffle write time.
-      curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime)
-    }
-
-    // No need to sort stuff, just write each element out
-    while (iterator.hasNext) {
-      val partitionId = iterator.nextPartition()
-      iterator.writeNext(partitionWriters(partitionId))
-    }
-  }
-
   /**
    * Merge a sequence of sorted files, giving an iterator over partitions and then over elements
    * inside each partition. This can be used to either write out a new file or return data to
@@ -665,8 +603,6 @@ private[spark] class ExternalSorter[K, V, C](
   }
 
   /**
-   * Exposed for testing purposes.
-   *
    * Return an iterator over all the data written to this object, grouped by partition and
    * aggregated by the requested aggregator. For each partition we then have an iterator over its
    * contents, and these are expected to be accessed in order (you can't "skip ahead" to one
@@ -676,10 +612,11 @@ private[spark] class ExternalSorter[K, V, C](
    * For now, we just merge all the spilled files in once pass, but this can be modified to
    * support hierarchical merging.
    */
-   def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
+  @VisibleForTesting
+  def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
     val usingMap = aggregator.isDefined
     val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
-    if (spills.isEmpty && partitionWriters == null) {
+    if (spills.isEmpty) {
       // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
       // we don't even need to sort by anything other than partition ID
       if (!ordering.isDefined) {
@@ -689,13 +626,6 @@ private[spark] class ExternalSorter[K, V, C](
         // We do need to sort by both partition ID and key
         groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator)))
       }
-    } else if (bypassMergeSort) {
-      // Read data from each partition file and merge it together with the data in memory;
-      // note that there's no ordering or aggregator in this case -- we just partition objects
-      val collIter = groupByPartition(collection.partitionedDestructiveSortedIterator(None))
-      collIter.map { case (partitionId, values) =>
-        (partitionId, values ++ readPartitionFile(partitionWriters(partitionId)))
-      }
     } else {
       // Merge spilled and in-memory data
       merge(spills, collection.partitionedDestructiveSortedIterator(comparator))
@@ -709,14 +639,13 @@ private[spark] class ExternalSorter[K, V, C](
 
   /**
    * Write all the data added into this ExternalSorter into a file in the disk store. This is
-   * called by the SortShuffleWriter and can go through an efficient path of just concatenating
-   * binary files if we decided to avoid merge-sorting.
+   * called by the SortShuffleWriter.
    *
    * @param blockId block ID to write to. The index file will be blockId.name + ".index".
    * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
    * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
    */
-  def writePartitionedFile(
+  override def writePartitionedFile(
       blockId: BlockId,
       context: TaskContext,
       outputFile: File): Array[Long] = {
@@ -724,28 +653,7 @@ private[spark] class ExternalSorter[K, V, C](
     // Track location of each range in the output file
     val lengths = new Array[Long](numPartitions)
 
-    if (bypassMergeSort && partitionWriters != null) {
-      // We decided to write separate files for each partition, so just concatenate them. To keep
-      // this simple we spill out the current in-memory collection so that everything is in files.
-      spillToPartitionFiles(if (aggregator.isDefined) map else buffer)
-      partitionWriters.foreach(_.commitAndClose())
-      val out = new FileOutputStream(outputFile, true)
-      val writeStartTime = System.nanoTime
-      util.Utils.tryWithSafeFinally {
-        for (i <- 0 until numPartitions) {
-          val in = new FileInputStream(partitionWriters(i).fileSegment().file)
-          util.Utils.tryWithSafeFinally {
-            lengths(i) = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled)
-          } {
-            in.close()
-          }
-        }
-      } {
-        out.close()
-        context.taskMetrics.shuffleWriteMetrics.foreach(
-          _.incShuffleWriteTime(System.nanoTime - writeStartTime))
-      }
-    } else if (spills.isEmpty && partitionWriters == null) {
+    if (spills.isEmpty) {
       // Case where we only have in-memory data
       val collection = if (aggregator.isDefined) map else buffer
       val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
@@ -761,7 +669,7 @@ private[spark] class ExternalSorter[K, V, C](
         lengths(partitionId) = segment.length
       }
     } else {
-      // Not bypassing merge-sort; get an iterator by partition and just write everything directly.
+      // We must perform merge-sort; get an iterator by partition and write everything directly.
       for ((id, elements) <- this.partitionedIterator) {
         if (elements.hasNext) {
           val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
@@ -778,41 +686,15 @@ private[spark] class ExternalSorter[K, V, C](
 
     context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
     context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
-    context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m =>
-      if (curWriteMetrics != null) {
-        m.incShuffleBytesWritten(curWriteMetrics.shuffleBytesWritten)
-        m.incShuffleWriteTime(curWriteMetrics.shuffleWriteTime)
-        m.incShuffleRecordsWritten(curWriteMetrics.shuffleRecordsWritten)
-      }
-    }
 
     lengths
   }
 
-  /**
-   * Read a partition file back as an iterator (used in our iterator method)
-   */
-  private def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = {
-    if (writer.isOpen) {
-      writer.commitAndClose()
-    }
-    new PairIterator[K, C](blockManager.diskStore.getValues(writer.blockId, ser).get)
-  }
-
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()
-    if (partitionWriters != null) {
-      partitionWriters.foreach { w =>
-        w.revertPartialWritesAndClose()
-        diskBlockManager.getFile(w.blockId).delete()
-      }
-      partitionWriters = null
-    }
   }
 
-  def diskBytesSpilled: Long = _diskBytesSpilled
-
   /**
    * Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*,
    * group together the pairs for each partition into a sub-iterator.
@@ -826,14 +708,6 @@ private[spark] class ExternalSorter[K, V, C](
     (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered)))
   }
 
-  private def comparator: Option[Comparator[K]] = {
-    if (ordering.isDefined || aggregator.isDefined) {
-      Some(keyComparator)
-    } else {
-      None
-    }
-  }
-
   /**
    * An iterator that reads only the elements for a given partition ID from an underlying buffered
    * stream, assuming this partition is the next one to be read. Used to make it easier to return
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
deleted file mode 100644
index d75959f480756f8a057640299d7b4eb158639fdc..0000000000000000000000000000000000000000
--- a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
+++ /dev/null
@@ -1,24 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] {
-  def hasNext: Boolean = iter.hasNext
-
-  def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V])
-}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
index e2e2f1faae9d1a54240128ad5d24eca5b362fc19..d0d25b43d047752e7d3685797dbab96a1348f134 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
@@ -34,10 +34,6 @@ private[spark] class PartitionedAppendOnlyMap[K, V]
     destructiveSortedIterator(comparator)
   }
 
-  def writablePartitionedIterator(): WritablePartitionedIterator = {
-    WritablePartitionedIterator.fromIterator(super.iterator)
-  }
-
   def insert(partition: Int, key: K, value: V): Unit = {
     update((partition, key), value)
   }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
index e8332e1a87eacf16625970e0ab71ac75df8d9154..5a6e9a9580e9bbca554bbde7a402acc42933de33 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
@@ -71,10 +71,6 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64)
     iterator
   }
 
-  override def writablePartitionedIterator(): WritablePartitionedIterator = {
-    WritablePartitionedIterator.fromIterator(iterator)
-  }
-
   private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] {
     var pos = 0
 
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
index 554d88206e22124832c858d7b39df620b54b62f3..862408b7a4d21d5a2d101f754b7d516e752f358c 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
@@ -122,10 +122,6 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
   override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
     : WritablePartitionedIterator = {
     sort(keyComparator)
-    writablePartitionedIterator
-  }
-
-  override def writablePartitionedIterator(): WritablePartitionedIterator = {
     new WritablePartitionedIterator {
       // current position in the meta buffer in ints
       var pos = 0
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
index f26d1618c9200ab7d1665c1970e60f4788e21c95..7bc59898658e48648da949a1c8f227fc5bfeb0e5 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -47,13 +47,20 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
    */
   def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
     : WritablePartitionedIterator = {
-    WritablePartitionedIterator.fromIterator(partitionedDestructiveSortedIterator(keyComparator))
-  }
+    val it = partitionedDestructiveSortedIterator(keyComparator)
+    new WritablePartitionedIterator {
+      private[this] var cur = if (it.hasNext) it.next() else null
 
-  /**
-   * Iterate through the data and write out the elements instead of returning them.
-   */
-  def writablePartitionedIterator(): WritablePartitionedIterator
+      def writeNext(writer: BlockObjectWriter): Unit = {
+        writer.write(cur._1._2, cur._2)
+        cur = if (it.hasNext) it.next() else null
+      }
+
+      def hasNext(): Boolean = cur != null
+
+      def nextPartition(): Int = cur._1._1
+    }
+  }
 }
 
 private[spark] object WritablePartitionedPairCollection {
@@ -94,20 +101,3 @@ private[spark] trait WritablePartitionedIterator {
 
   def nextPartition(): Int
 }
-
-private[spark] object WritablePartitionedIterator {
-  def fromIterator(it: Iterator[((Int, _), _)]): WritablePartitionedIterator = {
-    new WritablePartitionedIterator {
-      var cur = if (it.hasNext) it.next() else null
-
-      def writeNext(writer: BlockObjectWriter): Unit = {
-        writer.write(cur._1._2, cur._2)
-        cur = if (it.hasNext) it.next() else null
-      }
-
-      def hasNext(): Boolean = cur != null
-
-      def nextPartition(): Int = cur._1._1
-    }
-  }
-}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 91f4ab360857e11cce9dff179fadf53353d9d555..c3c2b1ffc1efacbcb7d513352879a7436dc4eb93 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.Matchers
 
 import org.apache.spark.ShuffleSuite.NonJavaSerializableClass
 import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId}
 import org.apache.spark.util.MutablePair
@@ -281,6 +282,39 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     // This count should retry the execution of the previous stage and rerun shuffle.
     rdd.count()
   }
+
+  test("metrics for shuffle without aggregation") {
+    sc = new SparkContext("local", "test", conf.clone())
+    val numRecords = 10000
+
+    val metrics = ShuffleSuite.runAndReturnMetrics(sc) {
+      sc.parallelize(1 to numRecords, 4)
+        .map(key => (key, 1))
+        .groupByKey()
+        .collect()
+    }
+
+    assert(metrics.recordsRead === numRecords)
+    assert(metrics.recordsWritten === numRecords)
+    assert(metrics.bytesWritten === metrics.byresRead)
+    assert(metrics.bytesWritten > 0)
+  }
+
+  test("metrics for shuffle with aggregation") {
+    sc = new SparkContext("local", "test", conf.clone())
+    val numRecords = 10000
+
+    val metrics = ShuffleSuite.runAndReturnMetrics(sc) {
+      sc.parallelize(1 to numRecords, 4)
+        .flatMap(key => Array.fill(100)((key, 1)))
+        .countByKey()
+    }
+
+    assert(metrics.recordsRead === numRecords)
+    assert(metrics.recordsWritten === numRecords)
+    assert(metrics.bytesWritten === metrics.byresRead)
+    assert(metrics.bytesWritten > 0)
+  }
 }
 
 object ShuffleSuite {
@@ -294,4 +328,35 @@ object ShuffleSuite {
       value - o.value
     }
   }
+
+  case class AggregatedShuffleMetrics(
+    recordsWritten: Long,
+    recordsRead: Long,
+    bytesWritten: Long,
+    byresRead: Long)
+
+  def runAndReturnMetrics(sc: SparkContext)(job: => Unit): AggregatedShuffleMetrics = {
+    @volatile var recordsWritten: Long = 0
+    @volatile var recordsRead: Long = 0
+    @volatile var bytesWritten: Long = 0
+    @volatile var bytesRead: Long = 0
+    val listener = new SparkListener {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+        taskEnd.taskMetrics.shuffleWriteMetrics.foreach { m =>
+          recordsWritten += m.shuffleRecordsWritten
+          bytesWritten += m.shuffleBytesWritten
+        }
+        taskEnd.taskMetrics.shuffleReadMetrics.foreach { m =>
+          recordsRead += m.recordsRead
+          bytesRead += m.totalBytesRead
+        }
+      }
+    }
+    sc.addSparkListener(listener)
+
+    job
+
+    sc.listenerBus.waitUntilEmpty(500)
+    AggregatedShuffleMetrics(recordsWritten, recordsRead, bytesWritten, bytesRead)
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
index 19f1af0dcd461b8aaf18453ff859c34230ca0dc4..9e4d34fb7d3820ec3d95478082fd1df6319efcce 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
@@ -193,26 +193,6 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
     assert(records == numRecords)
   }
 
-  test("shuffle records read metrics") {
-    val recordsRead = runAndReturnShuffleRecordsRead {
-      sc.textFile(tmpFilePath, 4)
-        .map(key => (key, 1))
-        .groupByKey()
-        .collect()
-    }
-    assert(recordsRead == numRecords)
-  }
-
-  test("shuffle records written metrics") {
-    val recordsWritten = runAndReturnShuffleRecordsWritten {
-      sc.textFile(tmpFilePath, 4)
-        .map(key => (key, 1))
-        .groupByKey()
-        .collect()
-    }
-    assert(recordsWritten == numRecords)
-  }
-
   /**
    * Tests the metrics from end to end.
    * 1) reading a hadoop file
@@ -301,14 +281,6 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
     runAndReturnMetrics(job, _.taskMetrics.outputMetrics.map(_.recordsWritten))
   }
 
-  private def runAndReturnShuffleRecordsRead(job: => Unit): Long = {
-    runAndReturnMetrics(job, _.taskMetrics.shuffleReadMetrics.map(_.recordsRead))
-  }
-
-  private def runAndReturnShuffleRecordsWritten(job: => Unit): Long = {
-    runAndReturnMetrics(job, _.taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten))
-  }
-
   private def runAndReturnMetrics(job: => Unit,
       collector: (SparkListenerTaskEnd) => Option[Long]): Long = {
     val taskMetrics = new ArrayBuffer[Long]()
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..c8420db6126c00cf47a0ca725192e214ca7757fe
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.File
+import java.util.UUID
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+
+import org.apache.spark._
+import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics}
+import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer}
+import org.apache.spark.storage._
+import org.apache.spark.util.Utils
+
+class BypassMergeSortShuffleWriterSuite extends FunSuite with BeforeAndAfterEach {
+
+  @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _
+
+  private var taskMetrics: TaskMetrics = _
+  private var shuffleWriteMetrics: ShuffleWriteMetrics = _
+  private var tempDir: File = _
+  private var outputFile: File = _
+  private val conf: SparkConf = new SparkConf(loadDefaults = false)
+  private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]()
+  private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File]
+  private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0)
+  private val serializer: Serializer = new JavaSerializer(conf)
+
+  override def beforeEach(): Unit = {
+    tempDir = Utils.createTempDir()
+    outputFile = File.createTempFile("shuffle", null, tempDir)
+    shuffleWriteMetrics = new ShuffleWriteMetrics
+    taskMetrics = new TaskMetrics
+    taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
+    MockitoAnnotations.initMocks(this)
+    when(taskContext.taskMetrics()).thenReturn(taskMetrics)
+    when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
+    when(blockManager.getDiskWriter(
+      any[BlockId],
+      any[File],
+      any[SerializerInstance],
+      anyInt(),
+      any[ShuffleWriteMetrics]
+    )).thenAnswer(new Answer[BlockObjectWriter] {
+      override def answer(invocation: InvocationOnMock): BlockObjectWriter = {
+        val args = invocation.getArguments
+        new DiskBlockObjectWriter(
+          args(0).asInstanceOf[BlockId],
+          args(1).asInstanceOf[File],
+          args(2).asInstanceOf[SerializerInstance],
+          args(3).asInstanceOf[Int],
+          compressStream = identity,
+          syncWrites = false,
+          args(4).asInstanceOf[ShuffleWriteMetrics]
+        )
+      }
+    })
+    when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
+      new Answer[(TempShuffleBlockId, File)] {
+        override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = {
+          val blockId = new TempShuffleBlockId(UUID.randomUUID)
+          val file = File.createTempFile(blockId.toString, null, tempDir)
+          blockIdToFileMap.put(blockId, file)
+          temporaryFilesCreated.append(file)
+          (blockId, file)
+        }
+      })
+    when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
+      new Answer[File] {
+        override def answer(invocation: InvocationOnMock): File = {
+          blockIdToFileMap.get(invocation.getArguments.head.asInstanceOf[BlockId]).get
+        }
+    })
+  }
+
+  override def afterEach(): Unit = {
+    Utils.deleteRecursively(tempDir)
+    blockIdToFileMap.clear()
+    temporaryFilesCreated.clear()
+  }
+
+  test("write empty iterator") {
+    val writer = new BypassMergeSortShuffleWriter[Int, Int](
+      new SparkConf(loadDefaults = false),
+      blockManager,
+      new HashPartitioner(7),
+      shuffleWriteMetrics,
+      serializer
+    )
+    writer.insertAll(Iterator.empty)
+    val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
+    assert(partitionLengths.sum === 0)
+    assert(outputFile.exists())
+    assert(outputFile.length() === 0)
+    assert(temporaryFilesCreated.isEmpty)
+    assert(shuffleWriteMetrics.shuffleBytesWritten === 0)
+    assert(shuffleWriteMetrics.shuffleRecordsWritten === 0)
+    assert(taskMetrics.diskBytesSpilled === 0)
+    assert(taskMetrics.memoryBytesSpilled === 0)
+  }
+
+  test("write with some empty partitions") {
+    def records: Iterator[(Int, Int)] =
+      Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
+    val writer = new BypassMergeSortShuffleWriter[Int, Int](
+      new SparkConf(loadDefaults = false),
+      blockManager,
+      new HashPartitioner(7),
+      shuffleWriteMetrics,
+      serializer
+    )
+    writer.insertAll(records)
+    assert(temporaryFilesCreated.nonEmpty)
+    val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
+    assert(partitionLengths.sum === outputFile.length())
+    assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
+    assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length())
+    assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length)
+    assert(taskMetrics.diskBytesSpilled === 0)
+    assert(taskMetrics.memoryBytesSpilled === 0)
+  }
+
+  test("cleanup of intermediate files after errors") {
+    val writer = new BypassMergeSortShuffleWriter[Int, Int](
+      new SparkConf(loadDefaults = false),
+      blockManager,
+      new HashPartitioner(7),
+      shuffleWriteMetrics,
+      serializer
+    )
+    intercept[SparkException] {
+      writer.insertAll((0 until 100000).iterator.map(i => {
+        if (i == 99990) {
+          throw new SparkException("Intentional failure")
+        }
+        (i, i)
+      }))
+    }
+    assert(temporaryFilesCreated.nonEmpty)
+    writer.stop()
+    assert(temporaryFilesCreated.count(_.exists()) === 0)
+  }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..c6ada7139c198ac6d5de152b002b195ca81453dc
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import org.mockito.Mockito._
+import org.scalatest.FunSuite
+
+import org.apache.spark.{Aggregator, SparkConf}
+
+class SortShuffleWriterSuite extends FunSuite {
+
+  import SortShuffleWriter._
+
+  test("conditions for bypassing merge-sort") {
+    val conf = new SparkConf(loadDefaults = false)
+    val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS)
+    val ord = implicitly[Ordering[Int]]
+
+    // Numbers of partitions that are above and below the default bypassMergeThreshold
+    val FEW_PARTITIONS = 50
+    val MANY_PARTITIONS = 10000
+
+    // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high
+    assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None))
+    assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None))
+
+    // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions
+    assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord)))
+    assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None))
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
index ad43a3e5fdc8834979d98b65d608741fe4269acb..7bdea724fea585b5b556db275421e55dc09793b3 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
@@ -18,14 +18,28 @@ package org.apache.spark.storage
 
 import java.io.File
 
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkConf
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.util.Utils
 
-class BlockObjectWriterSuite extends SparkFunSuite {
+class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+  var tempDir: File = _
+
+  override def beforeEach(): Unit = {
+    tempDir = Utils.createTempDir()
+  }
+
+  override def afterEach(): Unit = {
+    Utils.deleteRecursively(tempDir)
+  }
+
   test("verify write metrics") {
-    val file = new File(Utils.createTempDir(), "somefile")
+    val file = new File(tempDir, "somefile")
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
       new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
@@ -47,7 +61,7 @@ class BlockObjectWriterSuite extends SparkFunSuite {
   }
 
   test("verify write metrics on revert") {
-    val file = new File(Utils.createTempDir(), "somefile")
+    val file = new File(tempDir, "somefile")
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
       new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
@@ -70,7 +84,7 @@ class BlockObjectWriterSuite extends SparkFunSuite {
   }
 
   test("Reopening a closed block writer") {
-    val file = new File(Utils.createTempDir(), "somefile")
+    val file = new File(tempDir, "somefile")
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
       new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
@@ -81,4 +95,79 @@ class BlockObjectWriterSuite extends SparkFunSuite {
       writer.open()
     }
   }
+
+  test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    for (i <- 1 to 1000) {
+      writer.write(i, i)
+    }
+    writer.commitAndClose()
+    val bytesWritten = writeMetrics.shuffleBytesWritten
+    assert(writeMetrics.shuffleRecordsWritten === 1000)
+    writer.revertPartialWritesAndClose()
+    assert(writeMetrics.shuffleRecordsWritten === 1000)
+    assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+  }
+
+  test("commitAndClose() should be idempotent") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    for (i <- 1 to 1000) {
+      writer.write(i, i)
+    }
+    writer.commitAndClose()
+    val bytesWritten = writeMetrics.shuffleBytesWritten
+    val writeTime = writeMetrics.shuffleWriteTime
+    assert(writeMetrics.shuffleRecordsWritten === 1000)
+    writer.commitAndClose()
+    assert(writeMetrics.shuffleRecordsWritten === 1000)
+    assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+    assert(writeMetrics.shuffleWriteTime === writeTime)
+  }
+
+  test("revertPartialWritesAndClose() should be idempotent") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    for (i <- 1 to 1000) {
+      writer.write(i, i)
+    }
+    writer.revertPartialWritesAndClose()
+    val bytesWritten = writeMetrics.shuffleBytesWritten
+    val writeTime = writeMetrics.shuffleWriteTime
+    assert(writeMetrics.shuffleRecordsWritten === 0)
+    writer.revertPartialWritesAndClose()
+    assert(writeMetrics.shuffleRecordsWritten === 0)
+    assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+    assert(writeMetrics.shuffleWriteTime === writeTime)
+  }
+
+  test("fileSegment() can only be called after commitAndClose() has been called") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    for (i <- 1 to 1000) {
+      writer.write(i, i)
+    }
+    intercept[IllegalStateException] {
+      writer.fileSegment()
+    }
+    writer.close()
+  }
+
+  test("commitAndClose() without ever opening or writing") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    writer.commitAndClose()
+    assert(writer.fileSegment().length === 0)
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 9039dbef1fb718d3d027d876ba16a8b5c1a285fb..7d7b41bc23284740e8f7228054d044be0a5c4100 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -23,10 +23,12 @@ import org.scalatest.PrivateMethodTester
 
 import scala.util.Random
 
+import org.scalatest.FunSuite
+
 import org.apache.spark._
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 
-class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with PrivateMethodTester {
+class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
   private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = {
     val conf = new SparkConf(loadDefaults)
     if (kryo) {
@@ -37,21 +39,12 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
       conf.set("spark.serializer.objectStreamReset", "1")
       conf.set("spark.serializer", classOf[JavaSerializer].getName)
     }
+    conf.set("spark.shuffle.sort.bypassMergeThreshold", "0")
     // Ensure that we actually have multiple batches per spill file
     conf.set("spark.shuffle.spill.batchSize", "10")
     conf
   }
 
-  private def assertBypassedMergeSort(sorter: ExternalSorter[_, _, _]): Unit = {
-    val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort)
-    assert(sorter.invokePrivate(bypassMergeSort()), "sorter did not bypass merge-sort")
-  }
-
-  private def assertDidNotBypassMergeSort(sorter: ExternalSorter[_, _, _]): Unit = {
-    val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort)
-    assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort")
-  }
-
   test("empty data stream with kryo ser") {
     emptyDataStream(createSparkConf(false, true))
   }
@@ -161,39 +154,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
 
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), Some(ord), None)
-    assertDidNotBypassMergeSort(sorter)
-    sorter.insertAll(elements)
-    assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
-    val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
-    assert(iter.next() === (0, Nil))
-    assert(iter.next() === (1, List((1, 1))))
-    assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList))
-    assert(iter.next() === (3, Nil))
-    assert(iter.next() === (4, Nil))
-    assert(iter.next() === (5, List((5, 5))))
-    assert(iter.next() === (6, Nil))
-    sorter.stop()
-  }
-
-  test("empty partitions with spilling, bypass merge-sort with kryo ser") {
-    emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, true))
-  }
-
-  test("empty partitions with spilling, bypass merge-sort with java ser") {
-    emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, false))
-  }
-
-  def emptyPartitionerWithSpillingBypassMergeSort(conf: SparkConf) {
-    conf.set("spark.shuffle.memoryFraction", "0.001")
-    conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-
-    val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
-
-    val sorter = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(7)), None, None)
-    assertBypassedMergeSort(sorter)
     sorter.insertAll(elements)
     assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
     val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
@@ -376,7 +336,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
 
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
-    assertDidNotBypassMergeSort(sorter)
     sorter.insertAll((0 until 120000).iterator.map(i => (i, i)))
     assert(diskBlockManager.getAllFiles().length > 0)
     sorter.stop()
@@ -384,7 +343,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
 
     val sorter2 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
-    assertDidNotBypassMergeSort(sorter2)
     sorter2.insertAll((0 until 120000).iterator.map(i => (i, i)))
     assert(diskBlockManager.getAllFiles().length > 0)
     assert(sorter2.iterator.toSet === (0 until 120000).map(i => (i, i)).toSet)
@@ -392,29 +350,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
     assert(diskBlockManager.getAllBlocks().length === 0)
   }
 
-  test("cleanup of intermediate files in sorter, bypass merge-sort") {
-    val conf = createSparkConf(true, false)  // Load defaults, otherwise SPARK_HOME is not found
-    conf.set("spark.shuffle.memoryFraction", "0.001")
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
-
-    val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
-    assertBypassedMergeSort(sorter)
-    sorter.insertAll((0 until 100000).iterator.map(i => (i, i)))
-    assert(diskBlockManager.getAllFiles().length > 0)
-    sorter.stop()
-    assert(diskBlockManager.getAllBlocks().length === 0)
-
-    val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
-    assertBypassedMergeSort(sorter2)
-    sorter2.insertAll((0 until 100000).iterator.map(i => (i, i)))
-    assert(diskBlockManager.getAllFiles().length > 0)
-    assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet)
-    sorter2.stop()
-    assert(diskBlockManager.getAllBlocks().length === 0)
-  }
-
   test("cleanup of intermediate files in sorter if there are errors") {
     val conf = createSparkConf(true, false)  // Load defaults, otherwise SPARK_HOME is not found
     conf.set("spark.shuffle.memoryFraction", "0.001")
@@ -426,7 +361,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
 
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
-    assertDidNotBypassMergeSort(sorter)
     intercept[SparkException] {
       sorter.insertAll((0 until 120000).iterator.map(i => {
         if (i == 119990) {
@@ -440,28 +374,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
     assert(diskBlockManager.getAllBlocks().length === 0)
   }
 
-  test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") {
-    val conf = createSparkConf(true, false)  // Load defaults, otherwise SPARK_HOME is not found
-    conf.set("spark.shuffle.memoryFraction", "0.001")
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
-
-    val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
-    assertBypassedMergeSort(sorter)
-    intercept[SparkException] {
-      sorter.insertAll((0 until 100000).iterator.map(i => {
-        if (i == 99990) {
-          throw new SparkException("Intentional failure")
-        }
-        (i, i)
-      }))
-    }
-    assert(diskBlockManager.getAllFiles().length > 0)
-    sorter.stop()
-    assert(diskBlockManager.getAllBlocks().length === 0)
-  }
-
   test("cleanup of intermediate files in shuffle") {
     val conf = createSparkConf(false, false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
@@ -776,40 +688,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
     }
   }
 
-  test("conditions for bypassing merge-sort") {
-    val conf = createSparkConf(false, false)
-    conf.set("spark.shuffle.memoryFraction", "0.001")
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-
-    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
-    val ord = implicitly[Ordering[Int]]
-
-    // Numbers of partitions that are above and below the default bypassMergeThreshold
-    val FEW_PARTITIONS = 50
-    val MANY_PARTITIONS = 10000
-
-    // Sorters with no ordering or aggregator: should bypass unless # of partitions is high
-
-    val sorter1 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(FEW_PARTITIONS)), None, None)
-    assertBypassedMergeSort(sorter1)
-
-    val sorter2 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(MANY_PARTITIONS)), None, None)
-    assertDidNotBypassMergeSort(sorter2)
-
-    // Sorters with an ordering or aggregator: should not bypass even if they have few partitions
-
-    val sorter3 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(FEW_PARTITIONS)), Some(ord), None)
-    assertDidNotBypassMergeSort(sorter3)
-
-    val sorter4 = new ExternalSorter[Int, Int, Int](
-      Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None)
-    assertDidNotBypassMergeSort(sorter4)
-  }
-
   test("sort without breaking sorting contracts with kryo ser") {
     sortWithoutBreakingSortingContracts(createSparkConf(true, true))
   }