diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 9d4edeb6d96cff650bf53d0d97d816e28272a7fa..22d8d1cb1ddcfe48006d31b6f020210988c332d6 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -156,11 +156,9 @@ object SparkEnv extends Logging {
       conf.set("spark.driver.port", boundPort.toString)
     }
 
-    // Create an instance of the class named by the given Java system property, or by
-    // defaultClassName if the property is not set, and return it as a T
-    def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
-      val name = conf.get(propertyName,  defaultClassName)
-      val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader)
+    // Create an instance of the class with the given name, possibly initializing it with our conf
+    def instantiateClass[T](className: String): T = {
+      val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader)
       // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
       // SparkConf, then one taking no arguments
       try {
@@ -178,11 +176,17 @@ object SparkEnv extends Logging {
       }
     }
 
-    val serializer = instantiateClass[Serializer](
+    // Create an instance of the class named by the given SparkConf property, or defaultClassName
+    // if the property is not set, possibly initializing it with our conf
+    def instantiateClassFromConf[T](propertyName: String, defaultClassName: String): T = {
+      instantiateClass[T](conf.get(propertyName, defaultClassName))
+    }
+
+    val serializer = instantiateClassFromConf[Serializer](
       "spark.serializer", "org.apache.spark.serializer.JavaSerializer")
     logDebug(s"Using serializer: ${serializer.getClass}")
 
-    val closureSerializer = instantiateClass[Serializer](
+    val closureSerializer = instantiateClassFromConf[Serializer](
       "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")
 
     def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
@@ -246,8 +250,13 @@ object SparkEnv extends Logging {
       "."
     }
 
-    val shuffleManager = instantiateClass[ShuffleManager](
-      "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
+    // Let the user specify short names for shuffle managers
+    val shortShuffleMgrNames = Map(
+      "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
+      "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
+    val shuffleMgrName = conf.get("spark.shuffle.manager", "hash")
+    val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
+    val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
 
     val shuffleMemoryManager = new ShuffleMemoryManager(conf)
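To illustrate the short-name handling added above, a minimal, self-contained sketch of the lookup; the object and method names are invented for the example, and the real path in SparkEnv goes on to instantiate the resolved class reflectively via instantiateClass, preferring constructors that take a SparkConf.

```scala
// Illustrative sketch only -- mirrors the Map lookup in SparkEnv, not the real code path.
object ShuffleManagerNames {
  private val shortNames = Map(
    "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
    "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")

  /** Resolve the value of spark.shuffle.manager to a fully qualified class name. */
  def resolve(userValue: String): String =
    shortNames.getOrElse(userValue.toLowerCase, userValue)
}

// ShuffleManagerNames.resolve("SORT") == "org.apache.spark.shuffle.sort.SortShuffleManager"
// ShuffleManagerNames.resolve("hash") == "org.apache.spark.shuffle.hash.HashShuffleManager"
// ShuffleManagerNames.resolve("com.example.MyShuffleManager") is passed through unchanged
```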
 
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 7c9dc8e5f88efdc9f261e728b0e978ed71cbe56c..88a5f1e5ddf587eab7d3e1580db2d776581c9973 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -58,7 +58,7 @@ private[spark] class HashShuffleReader[K, C](
         // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
         // the ExternalSorter won't spill to disk.
         val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
-        sorter.write(aggregatedIter)
+        sorter.insertAll(aggregatedIter)
         context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
         context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
         sorter.iterator
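For context on the keyOrdering branch above, a sketch of the read-side sort: the aggregated records are fed into an ExternalSorter configured only with an ordering, and its iterator yields key-sorted output, spilling to disk if needed. ExternalSorter is private[spark] and needs a live SparkEnv, so this runs only inside Spark itself (for example in a test with a local SparkContext); the toy data is made up.

```scala
import org.apache.spark.util.collection.ExternalSorter

// Sketch, assuming a running SparkEnv; aggregator, partitioner and serializer default to None.
val keyOrd = implicitly[Ordering[Int]]
val sorter = new ExternalSorter[Int, String, String](ordering = Some(keyOrd))
sorter.insertAll(Iterator((3, "c"), (1, "a"), (2, "b")))
sorter.iterator.foreach(println)  // (1,a), (2,b), (3,c): keys sorted within the (single) partition
sorter.stop()
```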
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index e54e6383d2ccc112d588caa11e5bfc84ec994e6d..22f656fa371ea3af75538475e03c92dcf192377b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -44,6 +44,7 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   private var sorter: ExternalSorter[K, V, _] = null
   private var outputFile: File = null
+  private var indexFile: File = null
 
   // Are we in the process of stopping? Because map tasks can call stop() with success = true
   // and then call stop() with success = false if they get an exception, we want to make sure
@@ -57,78 +58,36 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
-    // Get an iterator with the elements for each partition ID
-    val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = {
-      if (dep.mapSideCombine) {
-        if (!dep.aggregator.isDefined) {
-          throw new IllegalStateException("Aggregator is empty for map-side combine")
-        }
-        sorter = new ExternalSorter[K, V, C](
-          dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
-        sorter.write(records)
-        sorter.partitionedIterator
-      } else {
-        // In this case we pass neither an aggregator nor an ordering to the sorter, because we
-        // don't care whether the keys get sorted in each partition; that will be done on the
-        // reduce side if the operation being run is sortByKey.
-        sorter = new ExternalSorter[K, V, V](
-          None, Some(dep.partitioner), None, dep.serializer)
-        sorter.write(records)
-        sorter.partitionedIterator
+    if (dep.mapSideCombine) {
+      if (!dep.aggregator.isDefined) {
+        throw new IllegalStateException("Aggregator is empty for map-side combine")
       }
+      sorter = new ExternalSorter[K, V, C](
+        dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+      sorter.insertAll(records)
+    } else {
+      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
+      // care whether the keys get sorted in each partition; that will be done on the reduce side
+      // if the operation being run is sortByKey.
+      sorter = new ExternalSorter[K, V, V](
+        None, Some(dep.partitioner), None, dep.serializer)
+      sorter.insertAll(records)
     }
 
     // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later
     // serve different ranges of this file using an index file that we create at the end.
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0)
-    outputFile = blockManager.diskBlockManager.getFile(blockId)
-
-    // Track location of each range in the output file
-    val offsets = new Array[Long](numPartitions + 1)
-    val lengths = new Array[Long](numPartitions)
-
-    for ((id, elements) <- partitions) {
-      if (elements.hasNext) {
-        val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize,
-          writeMetrics)
-        for (elem <- elements) {
-          writer.write(elem)
-        }
-        writer.commitAndClose()
-        val segment = writer.fileSegment()
-        offsets(id + 1) = segment.offset + segment.length
-        lengths(id) = segment.length
-      } else {
-        // The partition is empty; don't create a new writer to avoid writing headers, etc
-        offsets(id + 1) = offsets(id)
-      }
-    }
-
-    context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
-    context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
 
-    // Write an index file with the offsets of each block, plus a final offset at the end for the
-    // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure
-    // out where each block begins and ends.
+    outputFile = blockManager.diskBlockManager.getFile(blockId)
+    indexFile = blockManager.diskBlockManager.getFile(blockId.name + ".index")
 
-    val diskBlockManager = blockManager.diskBlockManager
-    val indexFile = diskBlockManager.getFile(blockId.name + ".index")
-    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
-    try {
-      var i = 0
-      while (i < numPartitions + 1) {
-        out.writeLong(offsets(i))
-        i += 1
-      }
-    } finally {
-      out.close()
-    }
+    val partitionLengths = sorter.writePartitionedFile(blockId, context)
 
     // Register our map output with the ShuffleBlockManager, which handles cleaning it over time
     blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions)
 
     mapStatus = new MapStatus(blockManager.blockManagerId,
-      lengths.map(MapOutputTracker.compressSize))
+      partitionLengths.map(MapOutputTracker.compressSize))
   }
 
   /** Close this writer, passing along whether the map completed */
@@ -145,6 +104,9 @@ private[spark] class SortShuffleWriter[K, V, C](
         if (outputFile != null) {
           outputFile.delete()
         }
+        if (indexFile != null) {
+          indexFile.delete()
+        }
         return None
       }
     } finally {
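The rewritten write() above produces one data file per map task plus a blockId.name + ".index" file holding numPartitions + 1 cumulative offsets as longs. A hedged sketch of how a reader can recover one reduce partition's byte range from that index; the real lookup lives in SortShuffleManager.getBlockLocation, which is not part of this diff, so the helper below is illustrative only.

```scala
import java.io.{DataInputStream, File, FileInputStream}

// Illustrative helper: partition i occupies bytes [offsets(i), offsets(i + 1)) of the data file.
def readPartitionRange(indexFile: File, reduceId: Int): (Long, Long) = {
  val in = new DataInputStream(new FileInputStream(indexFile))
  try {
    // Each offset is an 8-byte long; a robust reader would loop until enough bytes are skipped.
    in.skipBytes(reduceId * 8)
    val offset = in.readLong()
    val nextOffset = in.readLong()
    (offset, nextOffset - offset)  // (start offset, length) of the requested partition
  } finally {
    in.close()
  }
}
```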
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index eb4849ebc6e52357b36929ced1258859d53b7602..b73d5e0cf1714491b20231b1b50dc7ee16991873 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -25,10 +25,10 @@ import scala.collection.mutable
 
 import com.google.common.io.ByteStreams
 
-import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
+import org.apache.spark._
 import org.apache.spark.serializer.{DeserializationStream, Serializer}
-import org.apache.spark.storage.BlockId
 import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.storage.{BlockObjectWriter, BlockId}
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -67,6 +67,13 @@ import org.apache.spark.executor.ShuffleWriteMetrics
  *   for equality to merge values.
  *
  * - Users are expected to call stop() at the end to delete all the intermediate files.
+ *
+ * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is
+ * at most spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to
+ * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can
+ * then concatenate these files to produce a single sorted file, without having to serialize and
+ * de-serialize each item twice (as is needed during the merge). This speeds up the map side of
+ * groupBy, sort, etc. operations, since they do no partial aggregation.
  */
 private[spark] class ExternalSorter[K, V, C](
     aggregator: Option[Aggregator[K, V, C]] = None,
@@ -124,6 +131,18 @@ private[spark] class ExternalSorter[K, V, C](
   // How much of the shared memory pool this collection has claimed
   private var myMemoryThreshold = 0L
 
+  // If there are no more than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need
+  // local aggregation or sorting, write numPartitions files directly and just concatenate them
+  // at the end. This avoids doing serialization and deserialization twice to merge together the
+  // spilled files, which would happen with the normal code path. The downside is having multiple
+  // files open at a time and thus more memory allocated to buffers.
+  private val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+  private val bypassMergeSort =
+    (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty)
+
+  // Array of file writers for each partition, used if bypassMergeSort is true and we've spilled
+  private var partitionWriters: Array[BlockObjectWriter] = null
+
   // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
   // Can be a partial ordering by hash code if a total ordering is not provided by the
   // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
@@ -137,7 +156,14 @@ private[spark] class ExternalSorter[K, V, C](
     }
   })
 
-  // A comparator for (Int, K) elements that orders them by partition and then possibly by key
+  // A comparator for (Int, K) pairs that orders them by only their partition ID
+  private val partitionComparator: Comparator[(Int, K)] = new Comparator[(Int, K)] {
+    override def compare(a: (Int, K), b: (Int, K)): Int = {
+      a._1 - b._1
+    }
+  }
+
+  // A comparator that orders (Int, K) pairs by partition ID and then possibly by key
   private val partitionKeyComparator: Comparator[(Int, K)] = {
     if (ordering.isDefined || aggregator.isDefined) {
       // Sort by partition ID then key comparator
@@ -153,11 +179,7 @@ private[spark] class ExternalSorter[K, V, C](
       }
     } else {
       // Just sort it by partition ID
-      new Comparator[(Int, K)] {
-        override def compare(a: (Int, K), b: (Int, K)): Int = {
-          a._1 - b._1
-        }
-      }
+      partitionComparator
     }
   }
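Restating the bypassMergeSort condition added two hunks above as a standalone predicate, to make the decision explicit; the function name and example calls are illustrative, but the <= comparison and the two emptiness checks mirror the new val.

```scala
// Sketch: bypass the sort path only when records need no map-side aggregation, no key
// ordering, and the partition count is small enough that one open file (and its buffer)
// per partition is acceptable.
def shouldBypassMergeSort(
    numPartitions: Int,
    bypassMergeThreshold: Int,   // spark.shuffle.sort.bypassMergeThreshold, default 200
    hasAggregator: Boolean,
    hasOrdering: Boolean): Boolean = {
  numPartitions <= bypassMergeThreshold && !hasAggregator && !hasOrdering
}

// shouldBypassMergeSort(50, 200, false, false)    == true   (e.g. groupBy with few reducers)
// shouldBypassMergeSort(10000, 200, false, false) == false  (too many partition files/buffers)
// shouldBypassMergeSort(50, 200, true, false)     == false  (map-side combine needs the sort path)
```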
 
@@ -171,7 +193,7 @@ private[spark] class ExternalSorter[K, V, C](
     elementsPerPartition: Array[Long])
   private val spills = new ArrayBuffer[SpilledFile]
 
-  def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+  def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
     // TODO: stop combining if we find that the reduction factor isn't high
     val shouldCombine = aggregator.isDefined
 
@@ -242,6 +264,38 @@ private[spark] class ExternalSorter[K, V, C](
     val threadId = Thread.currentThread().getId
     logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)"
       .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
+
+    if (bypassMergeSort) {
+      spillToPartitionFiles(collection)
+    } else {
+      spillToMergeableFile(collection)
+    }
+
+    if (usingMap) {
+      map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+    } else {
+      buffer = new SizeTrackingPairBuffer[(Int, K), C]
+    }
+
+    // Release our memory back to the shuffle pool so that other threads can grab it
+    shuffleMemoryManager.release(myMemoryThreshold)
+    myMemoryThreshold = 0
+
+    _memoryBytesSpilled += memorySize
+  }
+
+  /**
+   * Spill our in-memory collection to a sorted file that we can merge later (normal code path).
+   * We add this file to the spills list so we can find it later.
+   *
+   * Alternatively, if bypassMergeSort is true, we spill to separate files for each partition.
+   * See spillToPartitionFiles() for that code path.
+   *
+   * @param collection whichever collection we're using (map or buffer)
+   */
+  private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
+    assert(!bypassMergeSort)
+
     val (blockId, file) = diskBlockManager.createTempBlock()
     curWriteMetrics = new ShuffleWriteMetrics()
     var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
@@ -304,18 +358,36 @@ private[spark] class ExternalSorter[K, V, C](
       }
     }
 
-    if (usingMap) {
-      map = new SizeTrackingAppendOnlyMap[(Int, K), C]
-    } else {
-      buffer = new SizeTrackingPairBuffer[(Int, K), C]
-    }
+    spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
+  }
 
-    // Release our memory back to the shuffle pool so that other threads can grab it
-    shuffleMemoryManager.release(myMemoryThreshold)
-    myMemoryThreshold = 0
+  /**
+   * Spill our in-memory collection to separate files, one for each partition. This is used when
+   * there's no aggregator or ordering and the number of partitions is small, because it allows
+   * writePartitionedFile to just concatenate files without deserializing data.
+   *
+   * @param collection whichever collection we're using (map or buffer)
+   */
+  private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
+    assert(bypassMergeSort)
+
+    // Create our file writers if we haven't done so yet
+    if (partitionWriters == null) {
+      curWriteMetrics = new ShuffleWriteMetrics()
+      partitionWriters = Array.fill(numPartitions) {
+        val (blockId, file) = diskBlockManager.createTempBlock()
+        blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open()
+      }
+    }
 
-    spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
-    _memoryBytesSpilled += memorySize
+    val it = collection.iterator  // No need to sort stuff, just write each element out
+    while (it.hasNext) {
+      val elem = it.next()
+      val partitionId = elem._1._1
+      val key = elem._1._2
+      val value = elem._2
+      partitionWriters(partitionId).write((key, value))
+    }
   }
 
   /**
@@ -479,7 +551,6 @@ private[spark] class ExternalSorter[K, V, C](
 
     skipToNextPartition()
 
-
     // Intermediate file and deserializer streams that read from exactly one batch
     // This guards against pre-fetching and other arbitrary behavior of higher level streams
     var fileStream: FileInputStream = null
@@ -619,23 +690,25 @@ private[spark] class ExternalSorter[K, V, C](
   def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
     val usingMap = aggregator.isDefined
     val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
-    if (spills.isEmpty) {
+    if (spills.isEmpty && partitionWriters == null) {
       // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
       // we don't even need to sort by anything other than partition ID
       if (!ordering.isDefined) {
-        // The user isn't requested sorted keys, so only sort by partition ID, not key
-        val partitionComparator = new Comparator[(Int, K)] {
-          override def compare(a: (Int, K), b: (Int, K)): Int = {
-            a._1 - b._1
-          }
-        }
+        // The user hasn't requested sorted keys, so only sort by partition ID, not key
         groupByPartition(collection.destructiveSortedIterator(partitionComparator))
       } else {
         // We do need to sort by both partition ID and key
         groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator))
       }
+    } else if (bypassMergeSort) {
+      // Read data from each partition file and merge it together with the data in memory;
+      // note that there's no ordering or aggregator in this case -- we just partition objects
+      val collIter = groupByPartition(collection.destructiveSortedIterator(partitionComparator))
+      collIter.map { case (partitionId, values) =>
+        (partitionId, values ++ readPartitionFile(partitionWriters(partitionId)))
+      }
     } else {
-      // General case: merge spilled and in-memory data
+      // Merge spilled and in-memory data
       merge(spills, collection.destructiveSortedIterator(partitionKeyComparator))
     }
   }
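A caller-level sketch of partitionedIterator, mirroring the suite's "empty partitions" tests: every partition ID from 0 to numPartitions - 1 is yielded exactly once, empty partitions included. As before, this assumes a live SparkEnv, since ExternalSorter is internal.

```scala
import org.apache.spark.HashPartitioner
import org.apache.spark.util.collection.ExternalSorter

// Sketch only; run inside Spark (e.g. a test with a local SparkContext).
val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
sorter.insertAll(Iterator((1, 1), (5, 5)))   // keys 1 and 5 hash to partitions 1 and 2
for ((partitionId, elements) <- sorter.partitionedIterator) {
  println(s"partition $partitionId -> ${elements.toList}")
}
// partition 0 -> List()          (empty partitions are still yielded)
// partition 1 -> List((1,1))
// partition 2 -> List((5,5))
sorter.stop()
```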
@@ -645,9 +718,113 @@ private[spark] class ExternalSorter[K, V, C](
    */
   def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2)
 
+  /**
+   * Write all the data added into this ExternalSorter into a file in the disk store, creating
+   * an .index file for it as well with the offsets of each partition. This is called by the
+   * SortShuffleWriter and can go through an efficient path of just concatenating binary files
+   * if we decided to avoid merge-sorting.
+   *
+   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
+   * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   */
+  def writePartitionedFile(blockId: BlockId, context: TaskContext): Array[Long] = {
+    val outputFile = blockManager.diskBlockManager.getFile(blockId)
+
+    // Track location of each range in the output file
+    val offsets = new Array[Long](numPartitions + 1)
+    val lengths = new Array[Long](numPartitions)
+
+    if (bypassMergeSort && partitionWriters != null) {
+      // We decided to write separate files for each partition, so just concatenate them. To keep
+      // this simple we spill out the current in-memory collection so that everything is in files.
+      spillToPartitionFiles(if (aggregator.isDefined) map else buffer)
+      partitionWriters.foreach(_.commitAndClose())
+      var out: FileOutputStream = null
+      var in: FileInputStream = null
+      try {
+        out = new FileOutputStream(outputFile)
+        for (i <- 0 until numPartitions) {
+          val file = partitionWriters(i).fileSegment().file
+          in = new FileInputStream(file)
+          org.apache.spark.util.Utils.copyStream(in, out)
+          in.close()
+          in = null
+          lengths(i) = file.length()
+          offsets(i + 1) = offsets(i) + lengths(i)
+        }
+      } finally {
+        if (out != null) {
+          out.close()
+        }
+        if (in != null) {
+          in.close()
+        }
+      }
+    } else {
+      // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by
+      // partition and just write everything directly.
+      for ((id, elements) <- this.partitionedIterator) {
+        if (elements.hasNext) {
+          val writer = blockManager.getDiskWriter(
+            blockId, outputFile, ser, fileBufferSize, context.taskMetrics.shuffleWriteMetrics.get)
+          for (elem <- elements) {
+            writer.write(elem)
+          }
+          writer.commitAndClose()
+          val segment = writer.fileSegment()
+          offsets(id + 1) = segment.offset + segment.length
+          lengths(id) = segment.length
+        } else {
+          // The partition is empty; don't create a new writer to avoid writing headers, etc
+          offsets(id + 1) = offsets(id)
+        }
+      }
+    }
+
+    context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
+    context.taskMetrics.diskBytesSpilled += diskBytesSpilled
+
+    // Write an index file with the offsets of each block, plus a final offset at the end for the
+    // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure
+    // out where each block begins and ends.
+
+    val diskBlockManager = blockManager.diskBlockManager
+    val indexFile = diskBlockManager.getFile(blockId.name + ".index")
+    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+    try {
+      var i = 0
+      while (i < numPartitions + 1) {
+        out.writeLong(offsets(i))
+        i += 1
+      }
+    } finally {
+      out.close()
+    }
+
+    lengths
+  }
+
+  /**
+   * Read a partition file back as an iterator (used by our partitionedIterator method)
+   */
+  def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = {
+    if (writer.isOpen) {
+      writer.commitAndClose()
+    }
+    blockManager.getLocalFromDisk(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]]
+  }
+
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()
+    if (partitionWriters != null) {
+      partitionWriters.foreach { w =>
+        w.revertPartialWritesAndClose()
+        diskBlockManager.getFile(w.blockId).delete()
+      }
+      partitionWriters = null
+    }
   }
 
   def memoryBytesSpilled: Long = _memoryBytesSpilled
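A worked example of the offsets/lengths bookkeeping in writePartitionedFile, with made-up sizes: for three partitions whose serialized output occupies 10, 0 and 5 bytes, the cumulative offsets and the index file contents come out as follows.

```scala
// Hypothetical sizes for numPartitions = 3.
val lengths = Array[Long](10, 0, 5)

// offsets(i + 1) = offsets(i) + lengths(i), as in writePartitionedFile.
val offsets = lengths.scanLeft(0L)(_ + _)   // Array(0, 10, 10, 15)

// The .index file then holds numPartitions + 1 longs: 0, 10, 10, 15.
// Partition i occupies bytes [offsets(i), offsets(i + 1)) of the single data file;
// the empty partition 1 has length offsets(2) - offsets(1) = 0, and the final
// offset (15) marks the end of the output file.
```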
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 57dcb4ffabac1f679bc69a059f9d4a1442cedb42..706faed980f31d0e2d9ef215ed0a377cdf183fe1 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -19,12 +19,12 @@ package org.apache.spark.util.collection
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.scalatest.FunSuite
+import org.scalatest.{PrivateMethodTester, FunSuite}
 
 import org.apache.spark._
 import org.apache.spark.SparkContext._
 
-class ExternalSorterSuite extends FunSuite with LocalSparkContext {
+class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester {
   private def createSparkConf(loadDefaults: Boolean): SparkConf = {
     val conf = new SparkConf(loadDefaults)
     // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
@@ -36,6 +36,16 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     conf
   }
 
+  private def assertBypassedMergeSort(sorter: ExternalSorter[_, _, _]): Unit = {
+    val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort)
+    assert(sorter.invokePrivate(bypassMergeSort()), "sorter did not bypass merge-sort")
+  }
+
+  private def assertDidNotBypassMergeSort(sorter: ExternalSorter[_, _, _]): Unit = {
+    val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort)
+    assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort")
+  }
+
   test("empty data stream") {
     val conf = new SparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
@@ -86,28 +96,28 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     // Both aggregator and ordering
     val sorter = new ExternalSorter[Int, Int, Int](
       Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
-    sorter.write(elements.iterator)
+    sorter.insertAll(elements.iterator)
     assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
     sorter.stop()
 
     // Only aggregator
     val sorter2 = new ExternalSorter[Int, Int, Int](
       Some(agg), Some(new HashPartitioner(7)), None, None)
-    sorter2.write(elements.iterator)
+    sorter2.insertAll(elements.iterator)
     assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
     sorter2.stop()
 
     // Only ordering
     val sorter3 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), Some(ord), None)
-    sorter3.write(elements.iterator)
+    sorter3.insertAll(elements.iterator)
     assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
     sorter3.stop()
 
     // Neither aggregator nor ordering
     val sorter4 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), None, None)
-    sorter4.write(elements.iterator)
+    sorter4.insertAll(elements.iterator)
     assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
     sorter4.stop()
   }
@@ -118,13 +128,37 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
 
-    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
     val ord = implicitly[Ordering[Int]]
     val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
 
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(7)), Some(ord), None)
+    assertDidNotBypassMergeSort(sorter)
+    sorter.insertAll(elements)
+    assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
+    val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
+    assert(iter.next() === (0, Nil))
+    assert(iter.next() === (1, List((1, 1))))
+    assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList))
+    assert(iter.next() === (3, Nil))
+    assert(iter.next() === (4, Nil))
+    assert(iter.next() === (5, List((5, 5))))
+    assert(iter.next() === (6, Nil))
+    sorter.stop()
+  }
+
+  test("empty partitions with spilling, bypass merge-sort") {
+    val conf = createSparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
+
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), None, None)
-    sorter.write(elements)
+    assertBypassedMergeSort(sorter)
+    sorter.insertAll(elements)
     assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
     val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
     assert(iter.next() === (0, Nil))
@@ -286,14 +320,43 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     sc = new SparkContext("local", "test", conf)
     val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
 
+    val ord = implicitly[Ordering[Int]]
+
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    assertDidNotBypassMergeSort(sorter)
+    sorter.insertAll((0 until 100000).iterator.map(i => (i, i)))
+    assert(diskBlockManager.getAllFiles().length > 0)
+    sorter.stop()
+    assert(diskBlockManager.getAllBlocks().length === 0)
+
+    val sorter2 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    assertDidNotBypassMergeSort(sorter2)
+    sorter2.insertAll((0 until 100000).iterator.map(i => (i, i)))
+    assert(diskBlockManager.getAllFiles().length > 0)
+    assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet)
+    sorter2.stop()
+    assert(diskBlockManager.getAllBlocks().length === 0)
+  }
+
+  test("cleanup of intermediate files in sorter, bypass merge-sort") {
+    val conf = createSparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
     val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
-    sorter.write((0 until 100000).iterator.map(i => (i, i)))
+    assertBypassedMergeSort(sorter)
+    sorter.insertAll((0 until 100000).iterator.map(i => (i, i)))
     assert(diskBlockManager.getAllFiles().length > 0)
     sorter.stop()
     assert(diskBlockManager.getAllBlocks().length === 0)
 
     val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
-    sorter2.write((0 until 100000).iterator.map(i => (i, i)))
+    assertBypassedMergeSort(sorter2)
+    sorter2.insertAll((0 until 100000).iterator.map(i => (i, i)))
     assert(diskBlockManager.getAllFiles().length > 0)
     assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet)
     sorter2.stop()
@@ -307,9 +370,35 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     sc = new SparkContext("local", "test", conf)
     val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
 
+    val ord = implicitly[Ordering[Int]]
+
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    assertDidNotBypassMergeSort(sorter)
+    intercept[SparkException] {
+      sorter.insertAll((0 until 100000).iterator.map(i => {
+        if (i == 99990) {
+          throw new SparkException("Intentional failure")
+        }
+        (i, i)
+      }))
+    }
+    assert(diskBlockManager.getAllFiles().length > 0)
+    sorter.stop()
+    assert(diskBlockManager.getAllBlocks().length === 0)
+  }
+
+  test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") {
+    val conf = createSparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
     val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+    assertBypassedMergeSort(sorter)
     intercept[SparkException] {
-      sorter.write((0 until 100000).iterator.map(i => {
+      sorter.insertAll((0 until 100000).iterator.map(i => {
         if (i == 99990) {
           throw new SparkException("Intentional failure")
         }
@@ -365,7 +454,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     sc = new SparkContext("local", "test", conf)
 
     val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
-    sorter.write((0 until 100000).iterator.map(i => (i / 4, i)))
+    sorter.insertAll((0 until 100000).iterator.map(i => (i / 4, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
     val expected = (0 until 3).map(p => {
       (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet)
@@ -381,7 +470,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
 
     val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
     val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
-    sorter.write((0 until 100).iterator.map(i => (i / 2, i)))
+    sorter.insertAll((0 until 100).iterator.map(i => (i / 2, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
     val expected = (0 until 3).map(p => {
       (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
@@ -397,7 +486,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
 
     val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
     val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
-    sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+    sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
     val expected = (0 until 3).map(p => {
       (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
@@ -414,7 +503,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
     val ord = implicitly[Ordering[Int]]
     val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
-    sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+    sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
     val expected = (0 until 3).map(p => {
       (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
@@ -431,7 +520,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val ord = implicitly[Ordering[Int]]
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
-    sorter.write((0 until 100).iterator.map(i => (i, i)))
+    sorter.insertAll((0 until 100).iterator.map(i => (i, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
     val expected = (0 until 3).map(p => {
       (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
@@ -448,7 +537,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val ord = implicitly[Ordering[Int]]
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
-    sorter.write((0 until 100000).iterator.map(i => (i, i)))
+    sorter.insertAll((0 until 100000).iterator.map(i => (i, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
     val expected = (0 until 3).map(p => {
       (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
@@ -495,7 +584,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++
       collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap)
 
-    sorter.write(toInsert)
+    sorter.insertAll(toInsert)
 
     // A map of collision pairs in both directions
     val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap
@@ -524,7 +613,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
     // problems if the map fails to group together the objects with the same code (SPARK-2043).
     val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1)
-    sorter.write(toInsert.iterator)
+    sorter.insertAll(toInsert.iterator)
 
     val it = sorter.iterator
     var count = 0
@@ -548,7 +637,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners)
     val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None)
 
-    sorter.write((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))
+    sorter.insertAll((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))
 
     val it = sorter.iterator
     while (it.hasNext) {
@@ -572,7 +661,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
       Some(agg), None, None, None)
 
-    sorter.write((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
+    sorter.insertAll((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
       (null.asInstanceOf[String], "1"),
       ("1", null.asInstanceOf[String]),
       (null.asInstanceOf[String], null.asInstanceOf[String])
@@ -584,4 +673,38 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
       it.next()
     }
   }
+
+  test("conditions for bypassing merge-sort") {
+    val conf = createSparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val ord = implicitly[Ordering[Int]]
+
+    // Numbers of partitions that are above and below the default bypassMergeThreshold
+    val FEW_PARTITIONS = 50
+    val MANY_PARTITIONS = 10000
+
+    // Sorters with no ordering or aggregator: should bypass unless # of partitions is high
+
+    val sorter1 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(FEW_PARTITIONS)), None, None)
+    assertBypassedMergeSort(sorter1)
+
+    val sorter2 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(MANY_PARTITIONS)), None, None)
+    assertDidNotBypassMergeSort(sorter2)
+
+    // Sorters with an ordering or aggregator: should not bypass even if they have few partitions
+
+    val sorter3 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(FEW_PARTITIONS)), Some(ord), None)
+    assertDidNotBypassMergeSort(sorter3)
+
+    val sorter4 = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None)
+    assertDidNotBypassMergeSort(sorter4)
+  }
 }
diff --git a/docs/configuration.md b/docs/configuration.md
index 5e3eb0f0871af4447311a6158b6cd774fefa751c..4d27c5a918fe09c9bb0584bbe0512daff31d316d 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -281,6 +281,24 @@ Apart from these, the following properties are also available, and may be useful
     overhead per reduce task, so keep it small unless you have a large amount of memory.
   </td>
 </tr>
+<tr>
+  <td><code>spark.shuffle.manager</code></td>
+  <td>HASH</td>
+  <td>
+    Implementation to use for shuffling data. A hash-based shuffle manager is the default, but
+    starting in Spark 1.1 there is an experimental sort-based shuffle manager that is more
+    memory-efficient in environments with small executors, such as YARN. To use that, change
+    this value to <code>SORT</code>.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.shuffle.sort.bypassMergeThreshold</code></td>
+  <td>200</td>
+  <td>
+    (Advanced) In the sort-based shuffle manager, avoid merge-sorting data if there is no
+    map-side aggregation and there are at most this many reduce partitions.
+  </td>
+</tr>
 </table>
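Per the two new entries above, a hedged example of selecting the sort-based shuffle and tuning the bypass threshold from application code; the app name and the threshold value 64 are arbitrary, and the same settings can be passed to spark-submit with --conf.

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("SortShuffleExample")                      // arbitrary app name
  .set("spark.shuffle.manager", "SORT")                  // short name; full class names also work
  .set("spark.shuffle.sort.bypassMergeThreshold", "64")  // advanced: concatenate per-partition files
                                                         // instead of merge-sorting when there are
                                                         // at most this many reduce partitions
val sc = new SparkContext(conf)
```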
 
 #### Spark UI