diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index f235c434be7b1f86d5e8ed440d0536e9e2544d4e..8a1771848dee6117d92e32d0fa588c3fda6da2e4 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -40,6 +40,8 @@ import org.apache.spark.annotation.Private;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
+import org.apache.commons.io.output.CloseShieldOutputStream;
+import org.apache.commons.io.output.CountingOutputStream;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
@@ -264,6 +266,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
     final boolean fastMergeIsSupported = !compressionEnabled ||
       CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
+    final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
     try {
       if (spills.length == 0) {
         new FileOutputStream(outputFile).close(); // Create an empty file
@@ -289,7 +292,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
         // Compression is disabled or we are using an IO compression codec that supports
         // decompression of concatenated compressed streams, so we can perform a fast spill merge
         // that doesn't need to interpret the spilled bytes.
-        if (transferToEnabled) {
+        if (transferToEnabled && !encryptionEnabled) {
           logger.debug("Using transferTo-based fast merge");
           partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
         } else {
@@ -320,9 +323,9 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   /**
    * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
    * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
-   * cases where the IO compression codec does not support concatenation of compressed data, or in
-   * cases where users have explicitly disabled use of {@code transferTo} in order to work around
-   * kernel bugs.
+   * cases where the IO compression codec does not support concatenation of compressed data, when
+   * encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in
+   * order to work around kernel bugs.
    *
    * @param spills the spills to merge.
    * @param outputFile the file to write the merged data to.
@@ -337,7 +340,11 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
     final InputStream[] spillInputStreams = new FileInputStream[spills.length];
-    OutputStream mergedFileOutputStream = null;
+
+    // Use a counting output stream to avoid having to close the underlying file and ask
+    // the file system for its size after each partition is written.
+    final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(
+        new FileOutputStream(outputFile));

     boolean threwException = true;
     try {
@@ -345,34 +352,35 @@
         spillInputStreams[i] = new FileInputStream(spills[i].file);
       }
       for (int partition = 0; partition < numPartitions; partition++) {
-        final long initialFileLength = outputFile.length();
-        mergedFileOutputStream =
-          new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
+        final long initialFileLength = mergedFileOutputStream.getByteCount();
+        // Shield the underlying output stream from close() calls, so that we can close the higher
+        // level streams to make sure all data is really flushed and internal state is cleaned.
+        OutputStream partitionOutput = new CloseShieldOutputStream(
+          new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
+        partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
         if (compressionCodec != null) {
-          mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
+          partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
         }
-
         for (int i = 0; i < spills.length; i++) {
           final long partitionLengthInSpill = spills[i].partitionLengths[partition];
           if (partitionLengthInSpill > 0) {
-            InputStream partitionInputStream = null;
-            boolean innerThrewException = true;
+            InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i],
+              partitionLengthInSpill, false);
             try {
-              partitionInputStream =
-                new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false);
+              partitionInputStream = blockManager.serializerManager().wrapForEncryption(
+                partitionInputStream);
               if (compressionCodec != null) {
                 partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
               }
-              ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
-              innerThrewException = false;
+              ByteStreams.copy(partitionInputStream, partitionOutput);
             } finally {
-              Closeables.close(partitionInputStream, innerThrewException);
+              partitionInputStream.close();
             }
           }
         }
-        mergedFileOutputStream.flush();
-        mergedFileOutputStream.close();
-        partitionLengths[partition] = (outputFile.length() - initialFileLength);
+        partitionOutput.flush();
+        partitionOutput.close();
+        partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength);
       }
       threwException = false;
     } finally {
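Illustrative note (not part of the patch): the per-partition output side of the new mergeSpillsWithFileStream path layers streams as sketched below. The sketch only reuses the classes and calls shown in the hunks above (Commons IO CountingOutputStream/CloseShieldOutputStream, Spark's TimeTrackingOutputStream, SerializerManager, CompressionCodec); names such as writeMetrics, blockManager, and numPartitions stand in for the writer's fields. Closing partitionOutput flushes the compression and encryption layers, while the close shield keeps the shared file stream open for the next partition.

    // Sketch only: mirrors the stream layering used per partition in the merge loop above.
    final CountingOutputStream mergedFileOutputStream =
        new CountingOutputStream(new FileOutputStream(outputFile));  // opened once, counts bytes
    for (int partition = 0; partition < numPartitions; partition++) {
      final long initialFileLength = mergedFileOutputStream.getByteCount();
      OutputStream partitionOutput = new CloseShieldOutputStream(      // close() won't reach the file
          new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
      partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
      if (compressionCodec != null) {
        partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
      }
      // ... copy this partition's bytes from each spill file into partitionOutput ...
      partitionOutput.flush();
      partitionOutput.close();  // flushes codec/cipher state; the underlying file stays open
      partitionLengths[partition] = mergedFileOutputStream.getByteCount() - initialFileLength;
    }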
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 7371f886575c62ac239a9282e1f42399cd5a2f17..686305e9335dc362aa210f3d41c2cbf62791ed55 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -75,6 +75,8 @@ private[spark] class SerializerManager(
    * loaded yet.
    */
   private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)

+  def encryptionEnabled: Boolean = encryptionKey.isDefined
+
   def canUseKryo(ct: ClassTag[_]): Boolean = {
     primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
   }
@@ -129,7 +131,7 @@ private[spark] class SerializerManager(
   /**
    * Wrap an input stream for encryption if shuffle encryption is enabled
    */
-  private[this] def wrapForEncryption(s: InputStream): InputStream = {
+  def wrapForEncryption(s: InputStream): InputStream = {
     encryptionKey
       .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
       .getOrElse(s)
@@ -138,7 +140,7 @@ private[spark] class SerializerManager(
   /**
    * Wrap an output stream for encryption if shuffle encryption is enabled
    */
-  private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
+  def wrapForEncryption(s: OutputStream): OutputStream = {
     encryptionKey
       .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
       .getOrElse(s)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 982b83324e0fce46b8377faf3a142c986f221192..04521c9159eacf97931367e20d1616959bf54e11 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -62,7 +62,7 @@ private[spark] class BlockManager(
     executorId: String,
     rpcEnv: RpcEnv,
     val master: BlockManagerMaster,
-    serializerManager: SerializerManager,
+    val serializerManager: SerializerManager,
     val conf: SparkConf,
     memoryManager: MemoryManager,
     mapOutputTracker: MapOutputTracker,
@@ -745,9 +745,8 @@ private[spark] class BlockManager(
       serializerInstance: SerializerInstance,
       bufferSize: Int,
       writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
-    val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
-    new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream,
+    new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize,
       syncWrites, writeMetrics, blockId)
   }
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index a499827ae159890c32659668650d15f203878899..3cb12fca7dccb1bb306d1a4fd64786b233c6e106 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -22,7 +22,7 @@ import java.nio.channels.FileChannel

 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.Logging
-import org.apache.spark.serializer.{SerializationStream, SerializerInstance}
+import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
 import org.apache.spark.util.Utils

 /**
@@ -37,9 +37,9 @@ import org.apache.spark.util.Utils
  */
 private[spark] class DiskBlockObjectWriter(
     val file: File,
+    serializerManager: SerializerManager,
     serializerInstance: SerializerInstance,
     bufferSize: Int,
-    wrapStream: OutputStream => OutputStream,
     syncWrites: Boolean,
     // These write metrics concurrently shared with other active DiskBlockObjectWriters who
     // are themselves performing writes. All updates must be relative.
@@ -116,7 +116,7 @@ private[spark] class DiskBlockObjectWriter(
       initialized = true
     }
-    bs = wrapStream(mcs)
+    bs = serializerManager.wrapStream(blockId, mcs)
     objOut = serializerInstance.serializeStream(bs)
     streamOpen = true
     this
   }
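Illustrative note (not part of the patch): with the constructor change above, callers pass the SerializerManager itself instead of a wrapStream function, and the writer wraps its own stream on open(). A hypothetical caller, mirroring BlockManager.getDiskWriter above and the test mocks further below, looks roughly like:

    // Sketch only: the writer now applies compression/encryption via serializerManager.wrapStream(blockId, ...).
    DiskBlockObjectWriter writer = new DiskBlockObjectWriter(
        file,
        serializerManager,   // replaces the old wrapStream: OutputStream => OutputStream parameter
        serializerInstance,
        bufferSize,
        false,               // syncWrites
        writeMetrics,
        blockId);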
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index a96cd82382e2cd06d3ed0018d07d1fee9ca88536..088b68132d9053606c03771a0cd9f67903f4e11b 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -26,11 +26,9 @@ import scala.Product2;
 import scala.Tuple2;
 import scala.Tuple2$;
 import scala.collection.Iterator;
-import scala.runtime.AbstractFunction1;

 import com.google.common.collect.HashMultiset;
 import com.google.common.collect.Iterators;
-import com.google.common.io.ByteStreams;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -53,6 +51,7 @@ import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.security.CryptoStreamUtils;
 import org.apache.spark.serializer.*;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.storage.*;
@@ -77,7 +76,6 @@ public class UnsafeShuffleWriterSuite {
   final LinkedList<File> spillFilesCreated = new LinkedList<>();
   SparkConf conf;
   final Serializer serializer = new KryoSerializer(new SparkConf());
-  final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf());
   TaskMetrics taskMetrics;

   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@@ -86,17 +84,6 @@ public class UnsafeShuffleWriterSuite {
   @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
   @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;

-  private final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      if (conf.getBoolean("spark.shuffle.compress", true)) {
-        return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
-      } else {
-        return stream;
-      }
-    }
-  }
-
   @After
   public void tearDown() {
     Utils.deleteRecursively(tempDir);
@@ -121,6 +108,11 @@ public class UnsafeShuffleWriterSuite {
     memoryManager = new TestMemoryManager(conf);
     taskMemoryManager = new TaskMemoryManager(memoryManager, 0);

+    // Some tests will override this manager because they change the configuration. This is a
+    // default for tests that don't need a specific one.
+    SerializerManager manager = new SerializerManager(serializer, conf);
+    when(blockManager.serializerManager()).thenReturn(manager);
+
     when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
     when(blockManager.getDiskWriter(
       any(BlockId.class),
@@ -131,12 +123,11 @@ public class UnsafeShuffleWriterSuite {
         @Override
         public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
           Object[] args = invocationOnMock.getArguments();
-
           return new DiskBlockObjectWriter(
             (File) args[1],
+            blockManager.serializerManager(),
             (SerializerInstance) args[2],
             (Integer) args[3],
-            new WrapStream(),
             false,
             (ShuffleWriteMetrics) args[4],
             (BlockId) args[0]
@@ -201,9 +192,10 @@ public class UnsafeShuffleWriterSuite {
     for (int i = 0; i < NUM_PARTITITONS; i++) {
       final long partitionSize = partitionSizesInMergedFile[i];
       if (partitionSize > 0) {
-        InputStream in = new FileInputStream(mergedOutputFile);
-        ByteStreams.skipFully(in, startOffset);
-        in = new LimitedInputStream(in, partitionSize);
+        FileInputStream fin = new FileInputStream(mergedOutputFile);
+        fin.getChannel().position(startOffset);
+        InputStream in = new LimitedInputStream(fin, partitionSize);
+        in = blockManager.serializerManager().wrapForEncryption(in);
         if (conf.getBoolean("spark.shuffle.compress", true)) {
           in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
         }
@@ -294,14 +286,32 @@ public class UnsafeShuffleWriterSuite {
   }

   private void testMergingSpills(
-      boolean transferToEnabled,
-      String compressionCodecName) throws IOException {
+      final boolean transferToEnabled,
+      String compressionCodecName,
+      boolean encrypt) throws Exception {
     if (compressionCodecName != null) {
       conf.set("spark.shuffle.compress", "true");
       conf.set("spark.io.compression.codec", compressionCodecName);
     } else {
       conf.set("spark.shuffle.compress", "false");
     }
+    conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), encrypt);
+
+    SerializerManager manager;
+    if (encrypt) {
+      manager = new SerializerManager(serializer, conf,
+        Option.apply(CryptoStreamUtils.createKey(conf)));
+    } else {
+      manager = new SerializerManager(serializer, conf);
+    }
+
+    when(blockManager.serializerManager()).thenReturn(manager);
+    testMergingSpills(transferToEnabled, encrypt);
+  }
+
+  private void testMergingSpills(
+      boolean transferToEnabled,
+      boolean encrypted) throws IOException {
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
     for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
@@ -324,6 +334,7 @@ public class UnsafeShuffleWriterSuite {
     for (long size: partitionSizesInMergedFile) {
       sumOfPartitionSizes += size;
     }
+
     assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
     assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile()));

@@ -338,42 +349,72 @@ public class UnsafeShuffleWriterSuite {

   @Test
   public void mergeSpillsWithTransferToAndLZF() throws Exception {
-    testMergingSpills(true, LZFCompressionCodec.class.getName());
+    testMergingSpills(true, LZFCompressionCodec.class.getName(), false);
   }

   @Test
   public void mergeSpillsWithFileStreamAndLZF() throws Exception {
-    testMergingSpills(false, LZFCompressionCodec.class.getName());
+    testMergingSpills(false, LZFCompressionCodec.class.getName(), false);
   }

   @Test
   public void mergeSpillsWithTransferToAndLZ4() throws Exception {
-    testMergingSpills(true, LZ4CompressionCodec.class.getName());
+    testMergingSpills(true, LZ4CompressionCodec.class.getName(), false);
   }

   @Test
   public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
-    testMergingSpills(false, LZ4CompressionCodec.class.getName());
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), false);
   }

   @Test
   public void mergeSpillsWithTransferToAndSnappy() throws Exception {
-    testMergingSpills(true, SnappyCompressionCodec.class.getName());
+    testMergingSpills(true, SnappyCompressionCodec.class.getName(), false);
   }

   @Test
   public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
-    testMergingSpills(false, SnappyCompressionCodec.class.getName());
+    testMergingSpills(false, SnappyCompressionCodec.class.getName(), false);
   }

   @Test
   public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
-    testMergingSpills(true, null);
+    testMergingSpills(true, null, false);
   }

   @Test
   public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
-    testMergingSpills(false, null);
+    testMergingSpills(false, null, false);
+  }
+
+  @Test
+  public void mergeSpillsWithCompressionAndEncryption() throws Exception {
+    // This should actually be translated to a "file stream merge" internally, just have the
+    // test to make sure that it's the case.
+    testMergingSpills(true, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Exception {
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception {
+    conf.set("spark.shuffle.unsafe.fastMergeEnabled", "false");
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithEncryptionAndNoCompression() throws Exception {
+    // This should actually be translated to a "file stream merge" internally, just have the
+    // test to make sure that it's the case.
+    testMergingSpills(true, null, true);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndEncryptionAndNoCompression() throws Exception {
+    testMergingSpills(false, null, true);
   }

   @Test
@@ -531,4 +572,5 @@ public class UnsafeShuffleWriterSuite {
       writer.stop(false);
     }
   }
+
 }
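Illustrative note (not part of the patch): the encryption-enabled tests above obtain an encrypting SerializerManager by generating an I/O encryption key and handing it to the constructor, roughly as follows; this only restates the setup shown in testMergingSpills(..., encrypt = true).

    // Sketch only: mirrors the test setup above.
    conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), true);
    byte[] key = CryptoStreamUtils.createKey(conf);
    SerializerManager manager = new SerializerManager(serializer, conf, Option.apply(key));
    when(blockManager.serializerManager()).thenReturn(manager);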
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 33709b454c4c98044895c213179030a6d7c781d1..26568146bf4d7dcf4601735a0bf5bd292a96ba1e 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -19,13 +19,11 @@ package org.apache.spark.unsafe.map;

 import java.io.File;
 import java.io.IOException;
-import java.io.OutputStream;
 import java.nio.ByteBuffer;
 import java.util.*;

 import scala.Tuple2;
 import scala.Tuple2$;
-import scala.runtime.AbstractFunction1;

 import org.junit.After;
 import org.junit.Assert;
@@ -75,13 +73,6 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
   @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;

-  private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      return stream;
-    }
-  }
-
   @Before
   public void setup() {
     memoryManager =
@@ -120,9 +111,9 @@ public abstract class AbstractBytesToBytesMapSuite {

           return new DiskBlockObjectWriter(
             (File) args[1],
+            serializerManager,
             (SerializerInstance) args[2],
             (Integer) args[3],
-            new WrapStream(),
             false,
             (ShuffleWriteMetrics) args[4],
             (BlockId) args[0]
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index a9cf8ff520ed432bfc5033d37b0ce15aa13d1071..fbbe530a132e18bc0f61742f0d4a2a401e578cec 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -19,14 +19,12 @@ package org.apache.spark.util.collection.unsafe.sort;

 import java.io.File;
 import java.io.IOException;
-import java.io.OutputStream;
 import java.util.Arrays;
 import java.util.LinkedList;
 import java.util.UUID;

 import scala.Tuple2;
 import scala.Tuple2$;
-import scala.runtime.AbstractFunction1;

 import org.junit.After;
 import org.junit.Before;
@@ -57,13 +55,15 @@ import static org.mockito.Mockito.*;

 public class UnsafeExternalSorterSuite {

+  private final SparkConf conf = new SparkConf();
+
   final LinkedList<File> spillFilesCreated = new LinkedList<>();

   final TestMemoryManager memoryManager =
-    new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false"));
+    new TestMemoryManager(conf.clone().set("spark.memory.offHeap.enabled", "false"));
   final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
   final SerializerManager serializerManager = new SerializerManager(
-    new JavaSerializer(new SparkConf()),
-    new SparkConf().set("spark.shuffle.spill.compress", "false"));
+    new JavaSerializer(conf),
+    conf.clone().set("spark.shuffle.spill.compress", "false"));
   // Use integer comparison for comparing prefixes (which are partition ids, in this case)
   final PrefixComparator prefixComparator = PrefixComparators.LONG;
   // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
@@ -86,14 +86,7 @@ public class UnsafeExternalSorterSuite {

   protected boolean shouldUseRadixSort() { return false; }

-  private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m");
-
-  private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      return stream;
-    }
-  }
+  private final long pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "4m");

   @Before
   public void setUp() {
@@ -126,9 +119,9 @@ public class UnsafeExternalSorterSuite {

         return new DiskBlockObjectWriter(
           (File) args[1],
+          serializerManager,
           (SerializerInstance) args[2],
           (Integer) args[3],
-          new WrapStream(),
           false,
           (ShuffleWriteMetrics) args[4],
           (BlockId) args[0]
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 442941685f1ae2fc6860071bf1dd955d8013f322..85ccb3347104805e57f8cd90f5be2dba6753ee4a 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterEach

 import org.apache.spark._
 import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
-import org.apache.spark.serializer.{JavaSerializer, SerializerInstance}
+import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.IndexShuffleBlockResolver
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
@@ -90,11 +90,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
     )).thenAnswer(new Answer[DiskBlockObjectWriter] {
       override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = {
         val args = invocation.getArguments
+        val manager = new SerializerManager(new JavaSerializer(conf), conf)
         new DiskBlockObjectWriter(
           args(1).asInstanceOf[File],
+          manager,
           args(2).asInstanceOf[SerializerInstance],
           args(3).asInstanceOf[Int],
-          wrapStream = identity,
           syncWrites = false,
           args(4).asInstanceOf[ShuffleWriteMetrics],
           blockId = args(0).asInstanceOf[BlockId]
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
index 684e978d11864833c1366e9be5a274fb667c5df0..bfb3ac4c15bcae8629222a7d39fdd34390f3f24a 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterEach

 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
 import org.apache.spark.util.Utils

 class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
@@ -42,11 +42,19 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
     }
   }

-  test("verify write metrics") {
+  private def createWriter(): (DiskBlockObjectWriter, File, ShuffleWriteMetrics) = {
     val file = new File(tempDir, "somefile")
+    val conf = new SparkConf()
+    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+      file, serializerManager, new JavaSerializer(new SparkConf()).newInstance(), 1024, true,
+      writeMetrics)
+    (writer, file, writeMetrics)
+  }
+
+  test("verify write metrics") {
+    val (writer, file, writeMetrics) = createWriter()

     writer.write(Long.box(20), Long.box(30))
     // Record metrics update on every write
@@ -66,10 +74,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }

   test("verify write metrics on revert") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, _, writeMetrics) = createWriter()

     writer.write(Long.box(20), Long.box(30))
     // Record metrics update on every write
@@ -89,10 +94,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }

   test("Reopening a closed block writer") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, _, _) = createWriter()

     writer.open()
     writer.close()
@@ -102,10 +104,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }

   test("calling revertPartialWritesAndClose() on a partial write should truncate up to commit") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()

     writer.write(Long.box(20), Long.box(30))
     val firstSegment = writer.commitAndGet()
@@ -120,10 +119,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }

   test("calling revertPartialWritesAndClose() after commit() should have no effect") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()

     writer.write(Long.box(20), Long.box(30))
     val firstSegment = writer.commitAndGet()
@@ -136,10 +132,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }

   test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
@@ -153,10 +146,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }

   test("commit() and close() should be idempotent") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
@@ -173,10 +163,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }

   test("revertPartialWritesAndClose() should be idempotent") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
@@ -191,10 +178,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }

   test("commit() and close() without ever opening or writing") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, _, _) = createWriter()
     val segment = writer.commitAndGet()
     writer.close()
     assert(segment.length === 0)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 5141e36d9e38d2848cad9790683775e5eada9f35..7f0838268a11164fc51f7870908d7977ffe52eda 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.util.collection
 import scala.collection.mutable.ArrayBuffer

 import org.apache.spark._
+import org.apache.spark.internal.config._
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.memory.MemoryTestingUtils

@@ -230,14 +231,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     }
   }

+  test("spilling with compression and encryption") {
+    testSimpleSpilling(Some(CompressionCodec.DEFAULT_COMPRESSION_CODEC), encrypt = true)
+  }
+
   /**
    * Test spilling through simple aggregations and cogroups.
    * If a compression codec is provided, use it. Otherwise, do not compress spills.
    */
-  private def testSimpleSpilling(codec: Option[String] = None): Unit = {
+  private def testSimpleSpilling(codec: Option[String] = None, encrypt: Boolean = false): Unit = {
     val size = 1000
     val conf = createSparkConf(loadDefaults = true, codec)  // Load defaults for Spark home
     conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString)
+    conf.set(IO_ENCRYPTION_ENABLED, encrypt)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
     assertSpilled(sc, "reduceByKey") {