diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index f235c434be7b1f86d5e8ed440d0536e9e2544d4e..8a1771848dee6117d92e32d0fa588c3fda6da2e4 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -40,6 +40,8 @@ import org.apache.spark.annotation.Private;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
+import org.apache.commons.io.output.CloseShieldOutputStream;
+import org.apache.commons.io.output.CountingOutputStream;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
@@ -264,6 +266,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
     final boolean fastMergeIsSupported = !compressionEnabled ||
       CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
+    final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
     try {
       if (spills.length == 0) {
         new FileOutputStream(outputFile).close(); // Create an empty file
@@ -289,7 +292,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
           // Compression is disabled or we are using an IO compression codec that supports
           // decompression of concatenated compressed streams, so we can perform a fast spill merge
           // that doesn't need to interpret the spilled bytes.
-          if (transferToEnabled) {
+          if (transferToEnabled && !encryptionEnabled) {
             logger.debug("Using transferTo-based fast merge");
             partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
           } else {
@@ -320,9 +323,9 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   /**
    * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
    * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
-   * cases where the IO compression codec does not support concatenation of compressed data, or in
-   * cases where users have explicitly disabled use of {@code transferTo} in order to work around
-   * kernel bugs.
+   * cases where the IO compression codec does not support concatenation of compressed data, when
+   * encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in
+   * order to work around kernel bugs.
    *
    * @param spills the spills to merge.
    * @param outputFile the file to write the merged data to.
@@ -337,7 +340,11 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
     final InputStream[] spillInputStreams = new FileInputStream[spills.length];
-    OutputStream mergedFileOutputStream = null;
+
+    // Use a counting output stream to avoid having to close the underlying file and ask
+    // the file system for its size after each partition is written.
+    final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(
+      new FileOutputStream(outputFile));
 
     boolean threwException = true;
     try {
@@ -345,34 +352,35 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
         spillInputStreams[i] = new FileInputStream(spills[i].file);
       }
       for (int partition = 0; partition < numPartitions; partition++) {
-        final long initialFileLength = outputFile.length();
-        mergedFileOutputStream =
-          new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
+        final long initialFileLength = mergedFileOutputStream.getByteCount();
+        // Shield the underlying output stream from close() calls, so that we can close the higher
+        // level streams to make sure all data is really flushed and internal state is cleaned.
+        OutputStream partitionOutput = new CloseShieldOutputStream(
+          new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
+        partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
         if (compressionCodec != null) {
-          mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
+          partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
         }
-
         for (int i = 0; i < spills.length; i++) {
           final long partitionLengthInSpill = spills[i].partitionLengths[partition];
           if (partitionLengthInSpill > 0) {
-            InputStream partitionInputStream = null;
-            boolean innerThrewException = true;
+            InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i],
+              partitionLengthInSpill, false);
             try {
-              partitionInputStream =
-                  new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false);
+              partitionInputStream = blockManager.serializerManager().wrapForEncryption(
+                partitionInputStream);
               if (compressionCodec != null) {
                 partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
               }
-              ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
-              innerThrewException = false;
+              ByteStreams.copy(partitionInputStream, partitionOutput);
             } finally {
-              Closeables.close(partitionInputStream, innerThrewException);
+              partitionInputStream.close();
             }
           }
         }
-        mergedFileOutputStream.flush();
-        mergedFileOutputStream.close();
-        partitionLengths[partition] = (outputFile.length() - initialFileLength);
+        partitionOutput.flush();
+        partitionOutput.close();
+        partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength);
       }
       threwException = false;
     } finally {
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 7371f886575c62ac239a9282e1f42399cd5a2f17..686305e9335dc362aa210f3d41c2cbf62791ed55 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -75,6 +75,8 @@ private[spark] class SerializerManager(
    * loaded yet. */
   private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
 
+  def encryptionEnabled: Boolean = encryptionKey.isDefined
+
   def canUseKryo(ct: ClassTag[_]): Boolean = {
     primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
   }
@@ -129,7 +131,7 @@ private[spark] class SerializerManager(
   /**
    * Wrap an input stream for encryption if shuffle encryption is enabled
    */
-  private[this] def wrapForEncryption(s: InputStream): InputStream = {
+  def wrapForEncryption(s: InputStream): InputStream = {
     encryptionKey
       .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
       .getOrElse(s)
@@ -138,7 +140,7 @@ private[spark] class SerializerManager(
   /**
    * Wrap an output stream for encryption if shuffle encryption is enabled
    */
-  private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
+  def wrapForEncryption(s: OutputStream): OutputStream = {
     encryptionKey
       .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
       .getOrElse(s)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 982b83324e0fce46b8377faf3a142c986f221192..04521c9159eacf97931367e20d1616959bf54e11 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -62,7 +62,7 @@ private[spark] class BlockManager(
     executorId: String,
     rpcEnv: RpcEnv,
     val master: BlockManagerMaster,
-    serializerManager: SerializerManager,
+    val serializerManager: SerializerManager,
     val conf: SparkConf,
     memoryManager: MemoryManager,
     mapOutputTracker: MapOutputTracker,
@@ -745,9 +745,8 @@ private[spark] class BlockManager(
       serializerInstance: SerializerInstance,
       bufferSize: Int,
       writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
-    val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
-    new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream,
+    new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize,
       syncWrites, writeMetrics, blockId)
   }
 
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index a499827ae159890c32659668650d15f203878899..3cb12fca7dccb1bb306d1a4fd64786b233c6e106 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -22,7 +22,7 @@ import java.nio.channels.FileChannel
 
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.Logging
-import org.apache.spark.serializer.{SerializationStream, SerializerInstance}
+import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
 import org.apache.spark.util.Utils
 
 /**
@@ -37,9 +37,9 @@ import org.apache.spark.util.Utils
  */
 private[spark] class DiskBlockObjectWriter(
     val file: File,
+    serializerManager: SerializerManager,
     serializerInstance: SerializerInstance,
     bufferSize: Int,
-    wrapStream: OutputStream => OutputStream,
     syncWrites: Boolean,
     // These write metrics concurrently shared with other active DiskBlockObjectWriters who
     // are themselves performing writes. All updates must be relative.
@@ -116,7 +116,7 @@ private[spark] class DiskBlockObjectWriter(
       initialized = true
     }
 
-    bs = wrapStream(mcs)
+    bs = serializerManager.wrapStream(blockId, mcs)
     objOut = serializerInstance.serializeStream(bs)
     streamOpen = true
     this
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index a96cd82382e2cd06d3ed0018d07d1fee9ca88536..088b68132d9053606c03771a0cd9f67903f4e11b 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -26,11 +26,9 @@ import scala.Product2;
 import scala.Tuple2;
 import scala.Tuple2$;
 import scala.collection.Iterator;
-import scala.runtime.AbstractFunction1;
 
 import com.google.common.collect.HashMultiset;
 import com.google.common.collect.Iterators;
-import com.google.common.io.ByteStreams;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -53,6 +51,7 @@ import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.security.CryptoStreamUtils;
 import org.apache.spark.serializer.*;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.storage.*;
@@ -77,7 +76,6 @@ public class UnsafeShuffleWriterSuite {
   final LinkedList<File> spillFilesCreated = new LinkedList<>();
   SparkConf conf;
   final Serializer serializer = new KryoSerializer(new SparkConf());
-  final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf());
   TaskMetrics taskMetrics;
 
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@@ -86,17 +84,6 @@ public class UnsafeShuffleWriterSuite {
   @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
   @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
 
-  private final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      if (conf.getBoolean("spark.shuffle.compress", true)) {
-        return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
-      } else {
-        return stream;
-      }
-    }
-  }
-
   @After
   public void tearDown() {
     Utils.deleteRecursively(tempDir);
@@ -121,6 +108,11 @@ public class UnsafeShuffleWriterSuite {
     memoryManager = new TestMemoryManager(conf);
     taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
 
+    // Some tests will override this manager because they change the configuration. This is a
+    // default for tests that don't need a specific one.
+    SerializerManager manager = new SerializerManager(serializer, conf);
+    when(blockManager.serializerManager()).thenReturn(manager);
+
     when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
     when(blockManager.getDiskWriter(
       any(BlockId.class),
@@ -131,12 +123,11 @@ public class UnsafeShuffleWriterSuite {
       @Override
       public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
         Object[] args = invocationOnMock.getArguments();
-
         return new DiskBlockObjectWriter(
           (File) args[1],
+          blockManager.serializerManager(),
           (SerializerInstance) args[2],
           (Integer) args[3],
-          new WrapStream(),
           false,
           (ShuffleWriteMetrics) args[4],
           (BlockId) args[0]
@@ -201,9 +192,10 @@ public class UnsafeShuffleWriterSuite {
     for (int i = 0; i < NUM_PARTITITONS; i++) {
       final long partitionSize = partitionSizesInMergedFile[i];
       if (partitionSize > 0) {
-        InputStream in = new FileInputStream(mergedOutputFile);
-        ByteStreams.skipFully(in, startOffset);
-        in = new LimitedInputStream(in, partitionSize);
+        FileInputStream fin = new FileInputStream(mergedOutputFile);
+        fin.getChannel().position(startOffset);
+        InputStream in = new LimitedInputStream(fin, partitionSize);
+        in = blockManager.serializerManager().wrapForEncryption(in);
         if (conf.getBoolean("spark.shuffle.compress", true)) {
           in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
         }
@@ -294,14 +286,32 @@ public class UnsafeShuffleWriterSuite {
   }
 
   private void testMergingSpills(
-      boolean transferToEnabled,
-      String compressionCodecName) throws IOException {
+      final boolean transferToEnabled,
+      String compressionCodecName,
+      boolean encrypt) throws Exception {
     if (compressionCodecName != null) {
       conf.set("spark.shuffle.compress", "true");
       conf.set("spark.io.compression.codec", compressionCodecName);
     } else {
       conf.set("spark.shuffle.compress", "false");
     }
+    conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), encrypt);
+
+    SerializerManager manager;
+    if (encrypt) {
+      manager = new SerializerManager(serializer, conf,
+        Option.apply(CryptoStreamUtils.createKey(conf)));
+    } else {
+      manager = new SerializerManager(serializer, conf);
+    }
+
+    when(blockManager.serializerManager()).thenReturn(manager);
+    testMergingSpills(transferToEnabled, encrypt);
+  }
+
+  private void testMergingSpills(
+      boolean transferToEnabled,
+      boolean encrypted) throws IOException {
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
     for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
@@ -324,6 +334,7 @@ public class UnsafeShuffleWriterSuite {
     for (long size: partitionSizesInMergedFile) {
       sumOfPartitionSizes += size;
     }
+
     assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
 
     assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile()));
@@ -338,42 +349,72 @@ public class UnsafeShuffleWriterSuite {
 
   @Test
   public void mergeSpillsWithTransferToAndLZF() throws Exception {
-    testMergingSpills(true, LZFCompressionCodec.class.getName());
+    testMergingSpills(true, LZFCompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithFileStreamAndLZF() throws Exception {
-    testMergingSpills(false, LZFCompressionCodec.class.getName());
+    testMergingSpills(false, LZFCompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithTransferToAndLZ4() throws Exception {
-    testMergingSpills(true, LZ4CompressionCodec.class.getName());
+    testMergingSpills(true, LZ4CompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
-    testMergingSpills(false, LZ4CompressionCodec.class.getName());
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithTransferToAndSnappy() throws Exception {
-    testMergingSpills(true, SnappyCompressionCodec.class.getName());
+    testMergingSpills(true, SnappyCompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
-    testMergingSpills(false, SnappyCompressionCodec.class.getName());
+    testMergingSpills(false, SnappyCompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
-    testMergingSpills(true, null);
+    testMergingSpills(true, null, false);
   }
 
   @Test
   public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
-    testMergingSpills(false, null);
+    testMergingSpills(false, null, false);
+  }
+
+  @Test
+  public void mergeSpillsWithCompressionAndEncryption() throws Exception {
+    // This should actually be translated to a "file stream merge" internally, just have the
+    // test to make sure that it's the case.
+    testMergingSpills(true, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Exception {
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception {
+    conf.set("spark.shuffle.unsafe.fastMergeEnabled", "false");
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithEncryptionAndNoCompression() throws Exception {
+    // This should actually be translated to a "file stream merge" internally, just have the
+    // test to make sure that it's the case.
+    testMergingSpills(true, null, true);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndEncryptionAndNoCompression() throws Exception {
+    testMergingSpills(false, null, true);
   }
 
   @Test
@@ -531,4 +572,5 @@ public class UnsafeShuffleWriterSuite {
       writer.stop(false);
     }
   }
+
 }
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 33709b454c4c98044895c213179030a6d7c781d1..26568146bf4d7dcf4601735a0bf5bd292a96ba1e 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -19,13 +19,11 @@ package org.apache.spark.unsafe.map;
 
 import java.io.File;
 import java.io.IOException;
-import java.io.OutputStream;
 import java.nio.ByteBuffer;
 import java.util.*;
 
 import scala.Tuple2;
 import scala.Tuple2$;
-import scala.runtime.AbstractFunction1;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -75,13 +73,6 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
   @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
 
-  private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      return stream;
-    }
-  }
-
   @Before
   public void setup() {
     memoryManager =
@@ -120,9 +111,9 @@ public abstract class AbstractBytesToBytesMapSuite {
 
         return new DiskBlockObjectWriter(
           (File) args[1],
+          serializerManager,
           (SerializerInstance) args[2],
           (Integer) args[3],
-          new WrapStream(),
           false,
           (ShuffleWriteMetrics) args[4],
           (BlockId) args[0]
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index a9cf8ff520ed432bfc5033d37b0ce15aa13d1071..fbbe530a132e18bc0f61742f0d4a2a401e578cec 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -19,14 +19,12 @@ package org.apache.spark.util.collection.unsafe.sort;
 
 import java.io.File;
 import java.io.IOException;
-import java.io.OutputStream;
 import java.util.Arrays;
 import java.util.LinkedList;
 import java.util.UUID;
 
 import scala.Tuple2;
 import scala.Tuple2$;
-import scala.runtime.AbstractFunction1;
 
 import org.junit.After;
 import org.junit.Before;
@@ -57,13 +55,15 @@ import static org.mockito.Mockito.*;
 
 public class UnsafeExternalSorterSuite {
 
+  private final SparkConf conf = new SparkConf();
+
   final LinkedList<File> spillFilesCreated = new LinkedList<>();
   final TestMemoryManager memoryManager =
-    new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false"));
+    new TestMemoryManager(conf.clone().set("spark.memory.offHeap.enabled", "false"));
   final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
   final SerializerManager serializerManager = new SerializerManager(
-    new JavaSerializer(new SparkConf()),
-    new SparkConf().set("spark.shuffle.spill.compress", "false"));
+    new JavaSerializer(conf),
+    conf.clone().set("spark.shuffle.spill.compress", "false"));
   // Use integer comparison for comparing prefixes (which are partition ids, in this case)
   final PrefixComparator prefixComparator = PrefixComparators.LONG;
   // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
@@ -86,14 +86,7 @@ public class UnsafeExternalSorterSuite {
 
   protected boolean shouldUseRadixSort() { return false; }
 
-  private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m");
-
-  private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      return stream;
-    }
-  }
+  private final long pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "4m");
 
   @Before
   public void setUp() {
@@ -126,9 +119,9 @@ public class UnsafeExternalSorterSuite {
 
         return new DiskBlockObjectWriter(
           (File) args[1],
+          serializerManager,
           (SerializerInstance) args[2],
           (Integer) args[3],
-          new WrapStream(),
           false,
           (ShuffleWriteMetrics) args[4],
           (BlockId) args[0]
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 442941685f1ae2fc6860071bf1dd955d8013f322..85ccb3347104805e57f8cd90f5be2dba6753ee4a 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark._
 import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
-import org.apache.spark.serializer.{JavaSerializer, SerializerInstance}
+import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.IndexShuffleBlockResolver
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
@@ -90,11 +90,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
     )).thenAnswer(new Answer[DiskBlockObjectWriter] {
       override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = {
         val args = invocation.getArguments
+        val manager = new SerializerManager(new JavaSerializer(conf), conf)
         new DiskBlockObjectWriter(
           args(1).asInstanceOf[File],
+          manager,
           args(2).asInstanceOf[SerializerInstance],
           args(3).asInstanceOf[Int],
-          wrapStream = identity,
           syncWrites = false,
           args(4).asInstanceOf[ShuffleWriteMetrics],
           blockId = args(0).asInstanceOf[BlockId]
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
index 684e978d11864833c1366e9be5a274fb667c5df0..bfb3ac4c15bcae8629222a7d39fdd34390f3f24a 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
 import org.apache.spark.util.Utils
 
 class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
@@ -42,11 +42,19 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
     }
   }
 
-  test("verify write metrics") {
+  private def createWriter(): (DiskBlockObjectWriter, File, ShuffleWriteMetrics) = {
     val file = new File(tempDir, "somefile")
+    val conf = new SparkConf()
+    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+      file, serializerManager, new JavaSerializer(new SparkConf()).newInstance(), 1024, true,
+      writeMetrics)
+    (writer, file, writeMetrics)
+  }
+
+  test("verify write metrics") {
+    val (writer, file, writeMetrics) = createWriter()
 
     writer.write(Long.box(20), Long.box(30))
     // Record metrics update on every write
@@ -66,10 +74,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("verify write metrics on revert") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, _, writeMetrics) = createWriter()
 
     writer.write(Long.box(20), Long.box(30))
     // Record metrics update on every write
@@ -89,10 +94,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("Reopening a closed block writer") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, _, _) = createWriter()
 
     writer.open()
     writer.close()
@@ -102,10 +104,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("calling revertPartialWritesAndClose() on a partial write should truncate up to commit") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
 
     writer.write(Long.box(20), Long.box(30))
     val firstSegment = writer.commitAndGet()
@@ -120,10 +119,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("calling revertPartialWritesAndClose() after commit() should have no effect") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
 
     writer.write(Long.box(20), Long.box(30))
     val firstSegment = writer.commitAndGet()
@@ -136,10 +132,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
@@ -153,10 +146,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("commit() and close() should be idempotent") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
@@ -173,10 +163,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("revertPartialWritesAndClose() should be idempotent") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
@@ -191,10 +178,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("commit() and close() without ever opening or writing") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, _, _) = createWriter()
     val segment = writer.commitAndGet()
     writer.close()
     assert(segment.length === 0)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 5141e36d9e38d2848cad9790683775e5eada9f35..7f0838268a11164fc51f7870908d7977ffe52eda 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.util.collection
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
+import org.apache.spark.internal.config._
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.memory.MemoryTestingUtils
 
@@ -230,14 +231,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
+  test("spilling with compression and encryption") {
+    testSimpleSpilling(Some(CompressionCodec.DEFAULT_COMPRESSION_CODEC), encrypt = true)
+  }
+
   /**
    * Test spilling through simple aggregations and cogroups.
    * If a compression codec is provided, use it. Otherwise, do not compress spills.
    */
-  private def testSimpleSpilling(codec: Option[String] = None): Unit = {
+  private def testSimpleSpilling(codec: Option[String] = None, encrypt: Boolean = false): Unit = {
     val size = 1000
     val conf = createSparkConf(loadDefaults = true, codec)  // Load defaults for Spark home
     conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString)
+    conf.set(IO_ENCRYPTION_ENABLED, encrypt)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
 
     assertSpilled(sc, "reduceByKey") {