diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 530712b5df4a80f8233fec55299e65e95bcbc5b4..696b930a26b9e6da3ca70dda4c589def189132e5 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -66,6 +66,11 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
    * Cumulative time spent performing blocking writes, in ns.
    */
   def timeWriting(): Long
+
+  /**
+   * Number of bytes written so far.
+   */
+  def bytesWritten: Long
 }
 
 /** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
@@ -183,7 +188,8 @@ private[spark] class DiskBlockObjectWriter(
   // Only valid if called after close()
   override def timeWriting() = _timeWriting
 
-  def bytesWritten: Long = {
+  // Only valid if called after commit()
+  override def bytesWritten: Long = {
     lastValidPosition - initialPosition
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index a8ef7fa8b63ebc69c35ef40cf55e47b747cc0da3..f3e1c38744d78b59aacd0191a2db14f6298bf0cf 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -50,7 +50,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
   addShutdownHook()
 
   /**
-   * Returns the phyiscal file segment in which the given BlockId is located.
+   * Returns the physical file segment in which the given BlockId is located.
    * If the BlockId has been mapped to a specific FileSegment, that will be returned.
    * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
    */
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 3d9b09ec33e2af61a8fc81cb6d7798b2ef6f3ab4..7eb300d46e6e2f6ca7323380675a434d5e96e65e 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -24,11 +24,11 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+import com.google.common.io.ByteStreams
 
 import org.apache.spark.{Logging, SparkEnv}
-import org.apache.spark.io.LZFCompressionCodec
-import org.apache.spark.serializer.{KryoDeserializationStream, Serializer}
-import org.apache.spark.storage.{BlockId, BlockManager, DiskBlockObjectWriter}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.{BlockId, BlockManager}
 
 /**
  * An append-only map that spills sorted content to disk when there is insufficient space for it
@@ -84,12 +84,15 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   // Number of in-memory pairs inserted before tracking the map's shuffle memory usage
   private val trackMemoryThreshold = 1000
 
-  // Size of object batches when reading/writing from serializers. Objects are written in
-  // batches, with each batch using its own serialization stream. This cuts down on the size
-  // of reference-tracking maps constructed when deserializing a stream.
-  //
-  // NOTE: Setting this too low can cause excess copying when serializing, since some serializers
-  // grow internal data structures by growing + copying every time the number of objects doubles.
+  /**
+   * Size of object batches when reading/writing from serializers.
+   *
+   * Objects are written in batches, with each batch using its own serialization stream. This
+   * cuts down on the size of reference-tracking maps constructed when deserializing a stream.
+   *
+   * NOTE: Setting this too low can cause excessive copying when serializing, since some serializers
+   * grow internal data structures by growing + copying every time the number of objects doubles.
+   */
   private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000)
 
   // How many times we have spilled so far
@@ -100,7 +103,6 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   private var _diskBytesSpilled = 0L
 
   private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
-  private val syncWrites = sparkConf.getBoolean("spark.shuffle.sync", false)
   private val comparator = new KCComparator[K, C]
   private val ser = serializer.newInstance()
 
@@ -153,37 +155,21 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)"
       .format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
     val (blockId, file) = diskBlockManager.createTempBlock()
+    var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
+    var objectsWritten = 0
 
-    /* IMPORTANT NOTE: To avoid having to keep large object graphs in memory, this approach
-    *  closes and re-opens serialization and compression streams within each file. This makes some
-     * assumptions about the way that serialization and compression streams work, specifically:
-     *
-     * 1) The serializer input streams do not pre-fetch data from the underlying stream.
-     *
-     * 2) Several compression streams can be opened, written to, and flushed on the write path
-     *    while only one compression input stream is created on the read path
-     *
-     * In practice (1) is only true for Java, so we add a special fix below to make it work for
-     * Kryo. (2) is only true for LZF and not Snappy, so we coerce this to use LZF.
-     *
-     * To avoid making these assumptions we should create an intermediate stream that batches
-     * objects and sends an EOF to the higher layer streams to make sure they never prefetch data.
-     * This is a bit tricky because, within each segment, you'd need to track the total number
-     * of bytes written and then re-wind and write it at the beginning of the segment. This will
-     * most likely require using the file channel API.
-     */
+    // List of batch sizes (bytes) in the order they are written to disk
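+    // Each recorded size is later used by DiskMapIterator to delimit its batch on the read path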
+    val batchSizes = new ArrayBuffer[Long]
 
-    val shouldCompress = blockManager.shouldCompress(blockId)
-    val compressionCodec = new LZFCompressionCodec(sparkConf)
-    def wrapForCompression(outputStream: OutputStream) = {
-      if (shouldCompress) compressionCodec.compressedOutputStream(outputStream) else outputStream
+    // Flush the disk writer's contents to disk, and update relevant variables
+    def flush() = {
+      writer.commit()
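+      // Each batch gets its own writer (see below), so bytesWritten covers this batch only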
+      val bytesWritten = writer.bytesWritten
+      batchSizes.append(bytesWritten)
+      _diskBytesSpilled += bytesWritten
+      objectsWritten = 0
     }
 
-    def getNewWriter = new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize,
-      wrapForCompression, syncWrites)
-
-    var writer = getNewWriter
-    var objectsWritten = 0
     try {
       val it = currentMap.destructiveSortedIterator(comparator)
       while (it.hasNext) {
@@ -192,22 +178,21 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
         objectsWritten += 1
 
         if (objectsWritten == serializerBatchSize) {
-          writer.commit()
+          flush()
           writer.close()
-          _diskBytesSpilled += writer.bytesWritten
-          writer = getNewWriter
-          objectsWritten = 0
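+          // The disk writer appends, so the new writer continues in the same temp file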
+          writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
         }
       }
-
-      if (objectsWritten > 0) writer.commit()
+      if (objectsWritten > 0) {
+        flush()
+      }
     } finally {
       // Partial failures cannot be tolerated; do not revert partial writes
       writer.close()
-      _diskBytesSpilled += writer.bytesWritten
     }
+
     currentMap = new SizeTrackingAppendOnlyMap[K, C]
-    spilledMaps.append(new DiskMapIterator(file, blockId))
+    spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
 
     // Reset the amount of shuffle memory used by this map in the global pool
     val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
@@ -239,12 +224,12 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   private class ExternalIterator extends Iterator[(K, C)] {
 
     // A fixed-size queue that maintains a buffer for each stream we are currently merging
-    val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
+    private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
 
     // Input streams are derived both from the in-memory map and spilled maps on disk
     // The in-memory map is sorted in place, while the spilled maps are already in sorted order
-    val sortedMap = currentMap.destructiveSortedIterator(comparator)
-    val inputStreams = Seq(sortedMap) ++ spilledMaps
+    private val sortedMap = currentMap.destructiveSortedIterator(comparator)
+    private val inputStreams = Seq(sortedMap) ++ spilledMaps
 
     inputStreams.foreach { it =>
       val kcPairs = getMorePairs(it)
@@ -252,11 +237,12 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     }
 
     /**
-     * Fetch from the given iterator until a key of different hash is retrieved. In the
-     * event of key hash collisions, this ensures no pairs are hidden from being merged.
+     * Fetch from the given iterator until a key with a different hash is retrieved.
+     *
+     * In the event of key hash collisions, this ensures no pairs are hidden from being merged.
      * Assume the given iterator is in sorted order.
      */
-    def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = {
+    private def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = {
       val kcPairs = new ArrayBuffer[(K, C)]
       if (it.hasNext) {
         var kc = it.next()
@@ -274,7 +260,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
      * If the given buffer contains a value for the given key, merge that value into
      * baseCombiner and remove the corresponding (K, C) pair from the buffer
      */
-    def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = {
+    private def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = {
       var i = 0
       while (i < buffer.pairs.size) {
         val (k, c) = buffer.pairs(i)
@@ -293,7 +279,8 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     override def hasNext: Boolean = mergeHeap.exists(!_.pairs.isEmpty)
 
     /**
-     * Select a key with the minimum hash, then combine all values with the same key from all input streams.
+     * Select a key with the minimum hash, then combine all values with the same key from all
+     * input streams.
      */
     override def next(): (K, C) = {
       // Select a key from the StreamBuffer that holds the lowest key hash
@@ -333,7 +320,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
      *
      * StreamBuffers are ordered by the minimum key hash found across all of their own pairs.
      */
-    case class StreamBuffer(iterator: Iterator[(K, C)], pairs: ArrayBuffer[(K, C)])
+    private case class StreamBuffer(iterator: Iterator[(K, C)], pairs: ArrayBuffer[(K, C)])
       extends Comparable[StreamBuffer] {
 
       def minKeyHash: Int = {
@@ -355,51 +342,53 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   /**
    * An iterator that returns (K, C) pairs in sorted order from an on-disk map
    */
-  private class DiskMapIterator(file: File, blockId: BlockId) extends Iterator[(K, C)] {
-    val fileStream = new FileInputStream(file)
-    val bufferedStream = new FastBufferedInputStream(fileStream, fileBufferSize)
-
-    val shouldCompress = blockManager.shouldCompress(blockId)
-    val compressionCodec = new LZFCompressionCodec(sparkConf)
-    val compressedStream =
-      if (shouldCompress) {
-        compressionCodec.compressedInputStream(bufferedStream)
+  private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
+    extends Iterator[(K, C)] {
+    private val fileStream = new FileInputStream(file)
+    private val bufferedStream = new FastBufferedInputStream(fileStream, fileBufferSize)
+
+    // An intermediate stream that reads from exactly one batch
+    // This guards against pre-fetching and other arbitrary behavior of higher-level streams
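+    // (Layering: file -> buffered -> batch-limited -> decompression -> deserialization)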
+    private var batchStream = nextBatchStream()
+    private var compressedStream = blockManager.wrapForCompression(blockId, batchStream)
+    private var deserializeStream = ser.deserializeStream(compressedStream)
+    private var nextItem: (K, C) = null
+    private var objectsRead = 0
+
+    /**
+     * Construct a stream that reads only from the next batch
+     */
+    private def nextBatchStream(): InputStream = {
+      if (batchSizes.length > 0) {
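+        // Guava's ByteStreams.limit returns a stream that yields at most this many bytes
+        // before reporting EOF, so the decompression layer never reads past the batch boundary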
+        ByteStreams.limit(bufferedStream, batchSizes.remove(0))
       } else {
+        // No more batches left
         bufferedStream
       }
-    var deserializeStream = ser.deserializeStream(compressedStream)
-    var objectsRead = 0
-
-    var nextItem: (K, C) = null
-    var eof = false
-
-    def readNextItem(): (K, C) = {
-      if (!eof) {
-        try {
-          if (objectsRead == serializerBatchSize) {
-            val newInputStream = deserializeStream match {
-              case stream: KryoDeserializationStream =>
-                // Kryo's serializer stores an internal buffer that pre-fetches from the underlying
-                // stream. We need to capture this buffer and feed it to the new serialization
-                // stream so that the bytes are not lost.
-                val kryoInput = stream.input
-                val remainingBytes = kryoInput.limit() - kryoInput.position()
-                val extraBuf = kryoInput.readBytes(remainingBytes)
-                new SequenceInputStream(new ByteArrayInputStream(extraBuf), compressedStream)
-              case _ => compressedStream
-            }
-            deserializeStream = ser.deserializeStream(newInputStream)
-            objectsRead = 0
-          }
-          objectsRead += 1
-          return deserializeStream.readObject().asInstanceOf[(K, C)]
-        } catch {
-          case e: EOFException =>
-            eof = true
-            cleanup()
+    }
+
+    /**
+     * Return the next (K, C) pair from the deserialization stream.
+     *
+     * If the current batch is drained, construct a stream for the next batch and read from it.
+     * If no more pairs are left, return null.
+     */
+    private def readNextItem(): (K, C) = {
+      try {
+        val item = deserializeStream.readObject().asInstanceOf[(K, C)]
+        objectsRead += 1
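+        // The write path started a new batch every serializerBatchSize objects,
+        // so roll over to the next batch stream at the same point here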
+        if (objectsRead == serializerBatchSize) {
+          batchStream = nextBatchStream()
+          compressedStream = blockManager.wrapForCompression(blockId, batchStream)
+          deserializeStream = ser.deserializeStream(compressedStream)
+          objectsRead = 0
         }
+        item
+      } catch {
+        case e: EOFException =>
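+          // An EOF here means the last batch has been fully read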
+          cleanup()
+          null
       }
-      null
     }
 
     override def hasNext: Boolean = {
@@ -419,7 +408,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     }
 
     // TODO: Ensure this gets called even if the iterator isn't drained.
-    def cleanup() {
+    private def cleanup() {
       deserializeStream.close()
       file.delete()
     }
diff --git a/docs/configuration.md b/docs/configuration.md
index 1f9fa7056697edda0aef7d060b3b7340c069d405..8e4c48c81f8beed5648e9d138f41d7678a186503 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -158,9 +158,7 @@ Apart from these, the following properties are also available, and may be useful
   <td>spark.shuffle.spill.compress</td>
   <td>true</td>
   <td>
-    Whether to compress data spilled during shuffles. If enabled, spill compression
-    always uses the `org.apache.spark.io.LZFCompressionCodec` codec, 
-    regardless of the value of `spark.io.compression.codec`.
+    Whether to compress data spilled during shuffles. Compression will use the codec specified by `spark.io.compression.codec`.
   </td>
 </tr>
 <tr>