diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index f3eaf22c0166e73ef38de8152972c39068ee2cc1..51d7fda0cb260cc0d40afacbfd00bb400b0e9de2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -18,9 +18,11 @@ package org.apache.spark.network.util; import java.io.Closeable; +import java.io.EOFException; import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; import java.nio.charset.StandardCharsets; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; @@ -344,4 +346,17 @@ public class JavaUtils { } } + /** + * Fills a buffer with data read from the channel. + */ + public static void readFully(ReadableByteChannel channel, ByteBuffer dst) throws IOException { + int expected = dst.remaining(); + while (dst.hasRemaining()) { + if (channel.read(dst) < 0) { + throw new EOFException(String.format("Not enough bytes in channel (expected %d).", + expected)); + } + } + } + } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 22d01c47e645d034c02fedf8ab49e6e7958d216f..039df75ce74fd6a7a0c6a243e09a4673a8c62d11 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -29,7 +29,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.Serializer -import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel} +import org.apache.spark.storage._ import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} @@ -141,10 +141,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } /** Fetch torrent blocks from the driver and/or other executors. */ - private def readBlocks(): Array[ChunkedByteBuffer] = { + private def readBlocks(): Array[BlockData] = { // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported // to the driver, so other executors can pull these chunks from this executor as well. 
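A minimal usage sketch for the JavaUtils.readFully helper added above, assuming the spark network-common classes are on the classpath; the file path, header size, and the ReadFullyExample wrapper object are invented purely for illustration:

import java.io.FileInputStream
import java.nio.ByteBuffer

import org.apache.spark.network.util.JavaUtils

object ReadFullyExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical file and fixed-size header, only for illustration.
    val channel = new FileInputStream("/tmp/example.bin").getChannel()
    try {
      val header = ByteBuffer.allocate(16)
      // Throws EOFException if the channel ends before the buffer is filled.
      JavaUtils.readFully(channel, header)
      header.flip()
      println(s"read ${header.remaining()} header bytes")
    } finally {
      channel.close()
    }
  }
}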
- val blocks = new Array[ChunkedByteBuffer](numBlocks) + val blocks = new Array[BlockData](numBlocks) val bm = SparkEnv.get.blockManager for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { @@ -173,7 +173,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) throw new SparkException( s"Failed to store $pieceId of $broadcastId in local BlockManager") } - blocks(pid) = b + blocks(pid) = new ByteBufferBlockData(b, true) case None => throw new SparkException(s"Failed to get $pieceId of $broadcastId") } @@ -219,18 +219,22 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) case None => logInfo("Started reading broadcast variable " + id) val startTimeMs = System.currentTimeMillis() - val blocks = readBlocks().flatMap(_.getChunks()) + val blocks = readBlocks() logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) - val obj = TorrentBroadcast.unBlockifyObject[T]( - blocks, SparkEnv.get.serializer, compressionCodec) - // Store the merged copy in BlockManager so other tasks on this executor don't - // need to re-fetch it. - val storageLevel = StorageLevel.MEMORY_AND_DISK - if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") + try { + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. + val storageLevel = StorageLevel.MEMORY_AND_DISK + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } + obj + } finally { + blocks.foreach(_.dispose()) } - obj } } } @@ -277,12 +281,11 @@ private object TorrentBroadcast extends Logging { } def unBlockifyObject[T: ClassTag]( - blocks: Array[ByteBuffer], + blocks: Array[InputStream], serializer: Serializer, compressionCodec: Option[CompressionCodec]): T = { require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") - val is = new SequenceInputStream( - blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration) + val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration) val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index cdd3b8d8512b1949ba054d109c1522eb5b53ad62..78dabb42ac9d252d66b4facc60837ee6f3b8427d 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -16,20 +16,23 @@ */ package org.apache.spark.security -import java.io.{InputStream, OutputStream} +import java.io.{EOFException, InputStream, OutputStream} +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, WritableByteChannel} import java.util.Properties import javax.crypto.KeyGenerator import javax.crypto.spec.{IvParameterSpec, SecretKeySpec} import scala.collection.JavaConverters._ +import com.google.common.io.ByteStreams import org.apache.commons.crypto.random._ import org.apache.commons.crypto.stream._ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ 
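Since unBlockifyObject now takes input streams instead of byte buffers, callers that still hold ByteBuffer blocks wrap them first, as the updated BroadcastSuite further down does. A condensed, hedged sketch of that adaptation; byteBufferBlocks, serializer and compressionCodec are placeholder names assumed to be in scope:

// TorrentBroadcast is package-private, so this only compiles from the
// org.apache.spark.broadcast package (as in BroadcastSuite).
val streams: Array[InputStream] = byteBufferBlocks.map { b =>
  new ChunkedByteBuffer(b).toInputStream(dispose = true)
}
val restored = TorrentBroadcast.unBlockifyObject[Array[Byte]](
  streams, serializer, compressionCodec)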
-import org.apache.spark.network.util.CryptoUtils +import org.apache.spark.network.util.{CryptoUtils, JavaUtils} /** * A util class for manipulating IO encryption and decryption streams. @@ -48,12 +51,27 @@ private[spark] object CryptoStreamUtils extends Logging { os: OutputStream, sparkConf: SparkConf, key: Array[Byte]): OutputStream = { - val properties = toCryptoConf(sparkConf) - val iv = createInitializationVector(properties) + val params = new CryptoParams(key, sparkConf) + val iv = createInitializationVector(params.conf) os.write(iv) - val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) - new CryptoOutputStream(transformationStr, properties, os, - new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) + new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec, + new IvParameterSpec(iv)) + } + + /** + * Wrap a `WritableByteChannel` for encryption. + */ + def createWritableChannel( + channel: WritableByteChannel, + sparkConf: SparkConf, + key: Array[Byte]): WritableByteChannel = { + val params = new CryptoParams(key, sparkConf) + val iv = createInitializationVector(params.conf) + val helper = new CryptoHelperChannel(channel) + + helper.write(ByteBuffer.wrap(iv)) + new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec, + new IvParameterSpec(iv)) } /** @@ -63,12 +81,27 @@ private[spark] object CryptoStreamUtils extends Logging { is: InputStream, sparkConf: SparkConf, key: Array[Byte]): InputStream = { - val properties = toCryptoConf(sparkConf) val iv = new Array[Byte](IV_LENGTH_IN_BYTES) - is.read(iv, 0, iv.length) - val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) - new CryptoInputStream(transformationStr, properties, is, - new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) + ByteStreams.readFully(is, iv) + val params = new CryptoParams(key, sparkConf) + new CryptoInputStream(params.transformation, params.conf, is, params.keySpec, + new IvParameterSpec(iv)) + } + + /** + * Wrap a `ReadableByteChannel` for decryption. + */ + def createReadableChannel( + channel: ReadableByteChannel, + sparkConf: SparkConf, + key: Array[Byte]): ReadableByteChannel = { + val iv = new Array[Byte](IV_LENGTH_IN_BYTES) + val buf = ByteBuffer.wrap(iv) + JavaUtils.readFully(channel, buf) + + val params = new CryptoParams(key, sparkConf) + new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec, + new IvParameterSpec(iv)) } def toCryptoConf(conf: SparkConf): Properties = { @@ -102,4 +135,34 @@ private[spark] object CryptoStreamUtils extends Logging { } iv } + + /** + * This class is a workaround for CRYPTO-125, that forces all bytes to be written to the + * underlying channel. Since the callers of this API are using blocking I/O, there are no + * concerns with regards to CPU usage here. 
+ */ + private class CryptoHelperChannel(sink: WritableByteChannel) extends WritableByteChannel { + + override def write(src: ByteBuffer): Int = { + val count = src.remaining() + while (src.hasRemaining()) { + sink.write(src) + } + count + } + + override def isOpen(): Boolean = sink.isOpen() + + override def close(): Unit = sink.close() + + } + + private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) { + + val keySpec = new SecretKeySpec(key, "AES") + val transformation = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) + val conf = toCryptoConf(sparkConf) + + } + } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 96b288b9cfb817d8601e46393d651c21245a428c..bb7ed8709ba8a9ebbd670c4f31d964c91ab3eb1b 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -148,14 +148,14 @@ private[spark] class SerializerManager( /** * Wrap an output stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } /** * Wrap an input stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } @@ -167,30 +167,26 @@ private[spark] class SerializerManager( val byteStream = new BufferedOutputStream(outputStream) val autoPick = !blockId.isInstanceOf[StreamBlockId] val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance() - ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() } /** Serializes into a chunked byte buffer. */ def dataSerialize[T: ClassTag]( blockId: BlockId, - values: Iterator[T], - allowEncryption: Boolean = true): ChunkedByteBuffer = { - dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]], - allowEncryption = allowEncryption) + values: Iterator[T]): ChunkedByteBuffer = { + dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]]) } /** Serializes into a chunked byte buffer. 
*/ def dataSerializeWithExplicitClassTag( blockId: BlockId, values: Iterator[_], - classTag: ClassTag[_], - allowEncryption: Boolean = true): ChunkedByteBuffer = { + classTag: ClassTag[_]): ChunkedByteBuffer = { val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) val byteStream = new BufferedOutputStream(bbos) val autoPick = !blockId.isInstanceOf[StreamBlockId] val ser = getSerializer(classTag, autoPick).newInstance() - val encrypted = if (allowEncryption) wrapForEncryption(byteStream) else byteStream - ser.serializeStream(wrapForCompression(blockId, encrypted)).writeAll(values).close() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() bbos.toChunkedByteBuffer } @@ -200,15 +196,13 @@ private[spark] class SerializerManager( */ def dataDeserializeStream[T]( blockId: BlockId, - inputStream: InputStream, - maybeEncrypted: Boolean = true) + inputStream: InputStream) (classTag: ClassTag[T]): Iterator[T] = { val stream = new BufferedInputStream(inputStream) val autoPick = !blockId.isInstanceOf[StreamBlockId] - val decrypted = if (maybeEncrypted) wrapForEncryption(inputStream) else inputStream getSerializer(classTag, autoPick) .newInstance() - .deserializeStream(wrapForCompression(blockId, decrypted)) + .deserializeStream(wrapForCompression(blockId, inputStream)) .asIterator.asInstanceOf[Iterator[T]] } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 991346a40af4e29bbe300fd7db97b7a3fb404ff8..fcda9fa65303acbbe9de5554c1a17f144d910647 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io._ import java.nio.ByteBuffer +import java.nio.channels.Channels import scala.collection.mutable import scala.collection.mutable.HashMap @@ -35,7 +36,7 @@ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.network._ -import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo @@ -55,6 +56,55 @@ private[spark] class BlockResult( val readMethod: DataReadMethod.Value, val bytes: Long) +/** + * Abstracts away how blocks are stored and provides different ways to read the underlying block + * data. Callers should call [[dispose()]] when they're done with the block. + */ +private[spark] trait BlockData { + + def toInputStream(): InputStream + + /** + * Returns a Netty-friendly wrapper for the block's data. 
+ * + * @see [[ManagedBuffer#convertToNetty()]] + */ + def toNetty(): Object + + def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer + + def toByteBuffer(): ByteBuffer + + def size: Long + + def dispose(): Unit + +} + +private[spark] class ByteBufferBlockData( + val buffer: ChunkedByteBuffer, + val shouldDispose: Boolean) extends BlockData { + + override def toInputStream(): InputStream = buffer.toInputStream(dispose = false) + + override def toNetty(): Object = buffer.toNetty + + override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = { + buffer.copy(allocator) + } + + override def toByteBuffer(): ByteBuffer = buffer.toByteBuffer + + override def size: Long = buffer.size + + override def dispose(): Unit = { + if (shouldDispose) { + buffer.dispose() + } + } + +} + /** * Manager running on every node (driver and executors) which provides interfaces for putting and * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). @@ -94,7 +144,7 @@ private[spark] class BlockManager( // Actual storage of where blocks are kept private[spark] val memoryStore = new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this) - private[spark] val diskStore = new DiskStore(conf, diskBlockManager) + private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager) memoryManager.setMemoryStore(memoryStore) // Note: depending on the memory manager, `maxMemory` may actually vary over time. @@ -304,7 +354,8 @@ private[spark] class BlockManager( shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { getLocalBytes(blockId) match { - case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer) + case Some(blockData) => + new BlockManagerManagedBuffer(blockInfoManager, blockId, blockData, true) case None => // If this block manager receives a request for a block that it doesn't have then it's // likely that the master has outdated block statuses for this block. Therefore, we send @@ -463,21 +514,22 @@ private[spark] class BlockManager( val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) } else if (level.useDisk && diskStore.contains(blockId)) { + val diskData = diskStore.getBytes(blockId) val iterToReturn: Iterator[Any] = { - val diskBytes = diskStore.getBytes(blockId) if (level.deserialized) { val diskValues = serializerManager.dataDeserializeStream( blockId, - diskBytes.toInputStream(dispose = true))(info.classTag) + diskData.toInputStream())(info.classTag) maybeCacheDiskValuesInMemory(info, blockId, level, diskValues) } else { - val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes) - .map {_.toInputStream(dispose = false)} - .getOrElse { diskBytes.toInputStream(dispose = true) } + val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskData) + .map { _.toInputStream(dispose = false) } + .getOrElse { diskData.toInputStream() } serializerManager.dataDeserializeStream(blockId, stream)(info.classTag) } } - val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId)) + val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, + releaseLockAndDispose(blockId, diskData)) Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) } else { handleLocalReadFailure(blockId) @@ -488,7 +540,7 @@ private[spark] class BlockManager( /** * Get block from the local block manager as serialized bytes. 
*/ - def getLocalBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { + def getLocalBytes(blockId: BlockId): Option[BlockData] = { logDebug(s"Getting local block $blockId as bytes") // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work @@ -496,9 +548,9 @@ private[spark] class BlockManager( val shuffleBlockResolver = shuffleManager.shuffleBlockResolver // TODO: This should gracefully handle case where local block is not available. Currently // downstream code will throw an exception. - Option( - new ChunkedByteBuffer( - shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())) + val buf = new ChunkedByteBuffer( + shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) + Some(new ByteBufferBlockData(buf, true)) } else { blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) } } @@ -510,7 +562,7 @@ private[spark] class BlockManager( * Must be called while holding a read lock on the block. * Releases the read lock upon exception; keeps the read lock upon successful return. */ - private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): ChunkedByteBuffer = { + private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): BlockData = { val level = info.level logDebug(s"Level for block $blockId is $level") // In order, try to read the serialized bytes from memory, then from disk, then fall back to @@ -525,17 +577,19 @@ private[spark] class BlockManager( diskStore.getBytes(blockId) } else if (level.useMemory && memoryStore.contains(blockId)) { // The block was not found on disk, so serialize an in-memory copy: - serializerManager.dataSerializeWithExplicitClassTag( - blockId, memoryStore.getValues(blockId).get, info.classTag) + new ByteBufferBlockData(serializerManager.dataSerializeWithExplicitClassTag( + blockId, memoryStore.getValues(blockId).get, info.classTag), true) } else { handleLocalReadFailure(blockId) } } else { // storage level is serialized if (level.useMemory && memoryStore.contains(blockId)) { - memoryStore.getBytes(blockId).get + new ByteBufferBlockData(memoryStore.getBytes(blockId).get, false) } else if (level.useDisk && diskStore.contains(blockId)) { - val diskBytes = diskStore.getBytes(blockId) - maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes).getOrElse(diskBytes) + val diskData = diskStore.getBytes(blockId) + maybeCacheDiskBytesInMemory(info, blockId, level, diskData) + .map(new ByteBufferBlockData(_, false)) + .getOrElse(diskData) } else { handleLocalReadFailure(blockId) } @@ -761,43 +815,15 @@ private[spark] class BlockManager( * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing * so may corrupt or change the data stored by the `BlockManager`. * - * @param encrypt If true, asks the block manager to encrypt the data block before storing, - * when I/O encryption is enabled. This is required for blocks that have been - * read from unencrypted sources, since all the BlockManager read APIs - * automatically do decryption. * @return true if the block was stored or false if an error occurred. 
*/ def putBytes[T: ClassTag]( blockId: BlockId, bytes: ChunkedByteBuffer, level: StorageLevel, - tellMaster: Boolean = true, - encrypt: Boolean = false): Boolean = { + tellMaster: Boolean = true): Boolean = { require(bytes != null, "Bytes is null") - - val bytesToStore = - if (encrypt && securityManager.ioEncryptionKey.isDefined) { - try { - val data = bytes.toByteBuffer - val in = new ByteBufferInputStream(data) - val byteBufOut = new ByteBufferOutputStream(data.remaining()) - val out = CryptoStreamUtils.createCryptoOutputStream(byteBufOut, conf, - securityManager.ioEncryptionKey.get) - try { - ByteStreams.copy(in, out) - } finally { - in.close() - out.close() - } - new ChunkedByteBuffer(byteBufOut.toByteBuffer) - } finally { - bytes.dispose() - } - } else { - bytes - } - - doPutBytes(blockId, bytesToStore, level, implicitly[ClassTag[T]], tellMaster) + doPutBytes(blockId, bytes, level, implicitly[ClassTag[T]], tellMaster) } /** @@ -828,8 +854,9 @@ private[spark] class BlockManager( val replicationFuture = if (level.replication > 1) { Future { // This is a blocking action and should run in futureExecutionContext which is a cached - // thread pool - replicate(blockId, bytes, level, classTag) + // thread pool. The ByteBufferBlockData wrapper is not disposed of to avoid releasing + // buffers that are owned by the caller. + replicate(blockId, new ByteBufferBlockData(bytes, false), level, classTag) }(futureExecutionContext) } else { null @@ -1008,8 +1035,9 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { fileOutputStream => - serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + serializerManager.dataSerializeStream(blockId, out, iter)(classTag) } size = diskStore.getSize(blockId) } else { @@ -1024,8 +1052,9 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { fileOutputStream => - partiallySerializedValues.finishWritingToStream(fileOutputStream) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + partiallySerializedValues.finishWritingToStream(out) } size = diskStore.getSize(blockId) } else { @@ -1035,8 +1064,9 @@ private[spark] class BlockManager( } } else if (level.useDisk) { - diskStore.put(blockId) { fileOutputStream => - serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + serializerManager.dataSerializeStream(blockId, out, iterator())(classTag) } size = diskStore.getSize(blockId) } @@ -1065,7 +1095,7 @@ private[spark] class BlockManager( try { replicate(blockId, bytesToReplicate, level, remoteClassTag) } finally { - bytesToReplicate.unmap() + bytesToReplicate.dispose() } logDebug("Put block %s remotely took %s" .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) @@ -1089,29 +1119,29 @@ private[spark] class BlockManager( blockInfo: BlockInfo, blockId: BlockId, level: StorageLevel, - diskBytes: ChunkedByteBuffer): Option[ChunkedByteBuffer] = { + diskData: BlockData): Option[ChunkedByteBuffer] = { require(!level.deserialized) if (level.useMemory) { // Synchronize on blockInfo to guard against a race 
condition where two readers both try to // put values read from disk into the MemoryStore. blockInfo.synchronized { if (memoryStore.contains(blockId)) { - diskBytes.dispose() + diskData.dispose() Some(memoryStore.getBytes(blockId).get) } else { val allocator = level.memoryMode match { case MemoryMode.ON_HEAP => ByteBuffer.allocate _ case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ } - val putSucceeded = memoryStore.putBytes(blockId, diskBytes.size, level.memoryMode, () => { + val putSucceeded = memoryStore.putBytes(blockId, diskData.size, level.memoryMode, () => { // https://issues.apache.org/jira/browse/SPARK-6076 // If the file size is bigger than the free memory, OOM will happen. So if we // cannot put it into MemoryStore, copyForMemory should not be created. That's why // this action is put into a `() => ChunkedByteBuffer` and created lazily. - diskBytes.copy(allocator) + diskData.toChunkedByteBuffer(allocator) }) if (putSucceeded) { - diskBytes.dispose() + diskData.dispose() Some(memoryStore.getBytes(blockId).get) } else { None @@ -1203,7 +1233,7 @@ private[spark] class BlockManager( replicate(blockId, data, storageLevel, info.classTag, existingReplicas) } finally { logDebug(s"Releasing lock for $blockId") - releaseLock(blockId) + releaseLockAndDispose(blockId, data) } } } @@ -1214,7 +1244,7 @@ private[spark] class BlockManager( */ private def replicate( blockId: BlockId, - data: ChunkedByteBuffer, + data: BlockData, level: StorageLevel, classTag: ClassTag[_], existingReplicas: Set[BlockManagerId] = Set.empty): Unit = { @@ -1256,7 +1286,7 @@ private[spark] class BlockManager( peer.port, peer.executorId, blockId, - new NettyManagedBuffer(data.toNetty), + new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false), tLevel, classTag) logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" + @@ -1339,10 +1369,11 @@ private[spark] class BlockManager( logInfo(s"Writing block $blockId to disk") data() match { case Left(elements) => - diskStore.put(blockId) { fileOutputStream => + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) serializerManager.dataSerializeStream( blockId, - fileOutputStream, + out, elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]]) } case Right(bytes) => @@ -1434,6 +1465,11 @@ private[spark] class BlockManager( } } + def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = { + blockInfoManager.unlock(blockId) + data.dispose() + } + def stop(): Unit = { blockTransferService.close() if (shuffleClient ne blockTransferService) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala index f66f94279855064c0fde3976d945d217e4b6f644..1ea0d378cbe871ffc29adcc1091d3d78875d367d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -17,31 +17,52 @@ package org.apache.spark.storage -import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} +import java.io.InputStream +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.util.io.ChunkedByteBuffer /** - * This [[ManagedBuffer]] wraps a [[ChunkedByteBuffer]] retrieved from the [[BlockManager]] + * This [[ManagedBuffer]] wraps a [[BlockData]] instance retrieved from the [[BlockManager]] * so that the 
corresponding block's read lock can be released once this buffer's references * are released. * + * If `dispose` is set to true, the [[BlockData]]will be disposed when the buffer's reference + * count drops to zero. + * * This is effectively a wrapper / bridge to connect the BlockManager's notion of read locks * to the network layer's notion of retain / release counts. */ private[storage] class BlockManagerManagedBuffer( blockInfoManager: BlockInfoManager, blockId: BlockId, - chunkedBuffer: ChunkedByteBuffer) extends NettyManagedBuffer(chunkedBuffer.toNetty) { + data: BlockData, + dispose: Boolean) extends ManagedBuffer { + + private val refCount = new AtomicInteger(1) + + override def size(): Long = data.size + + override def nioByteBuffer(): ByteBuffer = data.toByteBuffer() + + override def createInputStream(): InputStream = data.toInputStream() + + override def convertToNetty(): Object = data.toNetty() override def retain(): ManagedBuffer = { - super.retain() + refCount.incrementAndGet() val locked = blockInfoManager.lockForReading(blockId, blocking = false) assert(locked.isDefined) this - } + } override def release(): ManagedBuffer = { blockInfoManager.unlock(blockId) - super.release() + if (refCount.decrementAndGet() == 0 && dispose) { + data.dispose() + } + this } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index ca23e2391ed0294314daae05800663896fad483e..c6656341fcd15ceef821779e0b5de58f61456963 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,48 +17,67 @@ package org.apache.spark.storage -import java.io.{FileOutputStream, IOException, RandomAccessFile} +import java.io._ import java.nio.ByteBuffer +import java.nio.channels.{Channels, ReadableByteChannel, WritableByteChannel} import java.nio.channels.FileChannel.MapMode +import java.nio.charset.StandardCharsets.UTF_8 +import java.util.concurrent.ConcurrentHashMap -import com.google.common.io.Closeables +import scala.collection.mutable.ListBuffer -import org.apache.spark.SparkConf +import com.google.common.io.{ByteStreams, Closeables, Files} +import io.netty.channel.FileRegion +import io.netty.util.AbstractReferenceCounted + +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.security.CryptoStreamUtils +import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.ChunkedByteBuffer /** * Stores BlockManager blocks on disk. */ -private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) extends Logging { +private[spark] class DiskStore( + conf: SparkConf, + diskManager: DiskBlockManager, + securityManager: SecurityManager) extends Logging { private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") + private val blockSizes = new ConcurrentHashMap[String, Long]() - def getSize(blockId: BlockId): Long = { - diskManager.getFile(blockId.name).length - } + def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name) /** * Invokes the provided callback function to write the specific block. * * @throws IllegalStateException if the block already exists in the disk store. 
*/ - def put(blockId: BlockId)(writeFunc: FileOutputStream => Unit): Unit = { + def put(blockId: BlockId)(writeFunc: WritableByteChannel => Unit): Unit = { if (contains(blockId)) { throw new IllegalStateException(s"Block $blockId is already present in the disk store") } logDebug(s"Attempting to put block $blockId") val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) - val fileOutputStream = new FileOutputStream(file) + val out = new CountingWritableChannel(openForWrite(file)) var threwException: Boolean = true try { - writeFunc(fileOutputStream) + writeFunc(out) + blockSizes.put(blockId.name, out.getCount) threwException = false } finally { try { - Closeables.close(fileOutputStream, threwException) + out.close() + } catch { + case ioe: IOException => + if (!threwException) { + threwException = true + throw ioe + } } finally { if (threwException) { remove(blockId) @@ -73,41 +92,46 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e } def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = { - put(blockId) { fileOutputStream => - val channel = fileOutputStream.getChannel - Utils.tryWithSafeFinally { - bytes.writeFully(channel) - } { - channel.close() - } + put(blockId) { channel => + bytes.writeFully(channel) } } - def getBytes(blockId: BlockId): ChunkedByteBuffer = { + def getBytes(blockId: BlockId): BlockData = { val file = diskManager.getFile(blockId.name) - val channel = new RandomAccessFile(file, "r").getChannel - Utils.tryWithSafeFinally { - // For small files, directly read rather than memory map - if (file.length < minMemoryMapBytes) { - val buf = ByteBuffer.allocate(file.length.toInt) - channel.position(0) - while (buf.remaining() != 0) { - if (channel.read(buf) == -1) { - throw new IOException("Reached EOF before filling buffer\n" + - s"offset=0\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}") + val blockSize = getSize(blockId) + + securityManager.getIOEncryptionKey() match { + case Some(key) => + // Encrypted blocks cannot be memory mapped; return a special object that does decryption + // and provides InputStream / FileRegion implementations for reading the data. + new EncryptedBlockData(file, blockSize, conf, key) + + case _ => + val channel = new FileInputStream(file).getChannel() + if (blockSize < minMemoryMapBytes) { + // For small files, directly read rather than memory map. 
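The channel-based put contract is easiest to see from the caller's side. The sketch below mirrors the "block size tracking" test added to DiskStoreSuite later in this patch; conf, diskBlockManager and securityManager are assumed to be in scope:

val diskStore = new DiskStore(conf, diskBlockManager, securityManager)
val blockId = BlockId("rdd_1_2")
diskStore.put(blockId) { channel =>
  // The callback now receives a WritableByteChannel instead of a FileOutputStream.
  val buf = ByteBuffer.wrap(new Array[Byte](32))
  while (buf.hasRemaining()) {
    channel.write(buf)
  }
}
// getSize is now served from the in-memory map filled by CountingWritableChannel.
assert(diskStore.getSize(blockId) == 32L)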
+ Utils.tryWithSafeFinally { + val buf = ByteBuffer.allocate(blockSize.toInt) + JavaUtils.readFully(channel, buf) + buf.flip() + new ByteBufferBlockData(new ChunkedByteBuffer(buf), true) + } { + channel.close() + } + } else { + Utils.tryWithSafeFinally { + new ByteBufferBlockData( + new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)), true) + } { + channel.close() } } - buf.flip() - new ChunkedByteBuffer(buf) - } else { - new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)) - } - } { - channel.close() } } def remove(blockId: BlockId): Boolean = { + blockSizes.remove(blockId.name) val file = diskManager.getFile(blockId.name) if (file.exists()) { val ret = file.delete() @@ -124,4 +148,142 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e val file = diskManager.getFile(blockId.name) file.exists() } + + private def openForWrite(file: File): WritableByteChannel = { + val out = new FileOutputStream(file).getChannel() + try { + securityManager.getIOEncryptionKey().map { key => + CryptoStreamUtils.createWritableChannel(out, conf, key) + }.getOrElse(out) + } catch { + case e: Exception => + Closeables.close(out, true) + file.delete() + throw e + } + } + +} + +private class EncryptedBlockData( + file: File, + blockSize: Long, + conf: SparkConf, + key: Array[Byte]) extends BlockData { + + override def toInputStream(): InputStream = Channels.newInputStream(open()) + + override def toNetty(): Object = new ReadableChannelFileRegion(open(), blockSize) + + override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = { + val source = open() + try { + var remaining = blockSize + val chunks = new ListBuffer[ByteBuffer]() + while (remaining > 0) { + val chunkSize = math.min(remaining, Int.MaxValue) + val chunk = allocator(chunkSize.toInt) + remaining -= chunkSize + JavaUtils.readFully(source, chunk) + chunk.flip() + chunks += chunk + } + + new ChunkedByteBuffer(chunks.toArray) + } finally { + source.close() + } + } + + override def toByteBuffer(): ByteBuffer = { + // This is used by the block transfer service to replicate blocks. The upload code reads + // all bytes into memory to send the block to the remote executor, so it's ok to do this + // as long as the block fits in a Java array. 
+ assert(blockSize <= Int.MaxValue, "Block is too large to be wrapped in a byte buffer.") + val dst = ByteBuffer.allocate(blockSize.toInt) + val in = open() + try { + JavaUtils.readFully(in, dst) + dst.flip() + dst + } finally { + Closeables.close(in, true) + } + } + + override def size: Long = blockSize + + override def dispose(): Unit = { } + + private def open(): ReadableByteChannel = { + val channel = new FileInputStream(file).getChannel() + try { + CryptoStreamUtils.createReadableChannel(channel, conf, key) + } catch { + case e: Exception => + Closeables.close(channel, true) + throw e + } + } + +} + +private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long) + extends AbstractReferenceCounted with FileRegion { + + private var _transferred = 0L + + private val buffer = ByteBuffer.allocateDirect(64 * 1024) + buffer.flip() + + override def count(): Long = blockSize + + override def position(): Long = 0 + + override def transfered(): Long = _transferred + + override def transferTo(target: WritableByteChannel, pos: Long): Long = { + assert(pos == transfered(), "Invalid position.") + + var written = 0L + var lastWrite = -1L + while (lastWrite != 0) { + if (!buffer.hasRemaining()) { + buffer.clear() + source.read(buffer) + buffer.flip() + } + if (buffer.hasRemaining()) { + lastWrite = target.write(buffer) + written += lastWrite + } else { + lastWrite = 0 + } + } + + _transferred += written + written + } + + override def deallocate(): Unit = source.close() +} + +private class CountingWritableChannel(sink: WritableByteChannel) extends WritableByteChannel { + + private var count = 0L + + def getCount: Long = count + + override def write(src: ByteBuffer): Int = { + val written = sink.write(src) + if (written > 0) { + count += written + } + written + } + + override def isOpen(): Boolean = sink.isOpen() + + override def close(): Unit = sink.close() + } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 5efdd23f79a21df90292b9eebebe0dc32dcaf162..241aacd74b5860d5887c0fb27a5dfd199f95416a 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -236,14 +236,6 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** Helper methods for storage-related objects. */ private[spark] object StorageUtils extends Logging { - // Ewwww... Reflection!!! See the unmap method for justification - private val memoryMappedBufferFileDescriptorField = { - val mappedBufferClass = classOf[java.nio.MappedByteBuffer] - val fdField = mappedBufferClass.getDeclaredField("fd") - fdField.setAccessible(true) - fdField - } - /** * Attempt to clean up a ByteBuffer if it is direct or memory-mapped. This uses an *unsafe* Sun * API that will cause errors if one attempts to read from the disposed buffer. However, neither @@ -251,8 +243,6 @@ private[spark] object StorageUtils extends Logging { * pressure on the garbage collector. Waiting for garbage collection may lead to the depletion of * off-heap memory or huge numbers of open files. There's unfortunately no standard API to * manually dispose of these kinds of buffers. - * - * See also [[unmap]] */ def dispose(buffer: ByteBuffer): Unit = { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { @@ -261,28 +251,6 @@ private[spark] object StorageUtils extends Logging { } } - /** - * Attempt to unmap a ByteBuffer if it is memory-mapped. 
This uses an *unsafe* Sun API that will - * cause errors if one attempts to read from the unmapped buffer. However, the file descriptors of - * memory-mapped buffers do not put pressure on the garbage collector. Waiting for garbage - * collection may lead to huge numbers of open files. There's unfortunately no standard API to - * manually unmap memory-mapped buffers. - * - * See also [[dispose]] - */ - def unmap(buffer: ByteBuffer): Unit = { - if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - // Note that direct buffers are instances of MappedByteBuffer. As things stand in Java 8, the - // JDK does not provide a public API to distinguish between direct buffers and memory-mapped - // buffers. As an alternative, we peek beneath the curtains and look for a non-null file - // descriptor in mappedByteBuffer - if (memoryMappedBufferFileDescriptorField.get(buffer) != null) { - logTrace(s"Unmapping $buffer") - cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer]) - } - } - } - private def cleanDirectBuffer(buffer: DirectBuffer) = { val cleaner = buffer.cleaner() if (cleaner != null) { diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index fb54dd66a39a90cc129ce0dd0c2d4ab312838e22..90e3af2d0ec74e116d1c0b3679422ccfcf802b1f 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -344,7 +344,7 @@ private[spark] class MemoryStore( val serializationStream: SerializationStream = { val autoPick = !blockId.isInstanceOf[StreamBlockId] val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() - ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream)) + ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) } // Request enough memory to begin unrolling diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 1667516663b35b0cd5e1bdc4d046cc4c2259dec9..2f905c8af0f6394d5a7d0c4d3fcf587fb4b10aff 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -138,8 +138,6 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { /** * Attempt to clean up any ByteBuffer in this ChunkedByteBuffer which is direct or memory-mapped. * See [[StorageUtils.dispose]] for more information. - * - * See also [[unmap]] */ def dispose(): Unit = { if (!disposed) { @@ -148,18 +146,6 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } } - /** - * Attempt to unmap any ByteBuffer in this ChunkedByteBuffer if it is memory-mapped. See - * [[StorageUtils.unmap]] for more information. 
- * - * See also [[dispose]] - */ - def unmap(): Unit = { - if (!disposed) { - chunks.foreach(StorageUtils.unmap) - disposed = true - } - } } /** diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 4e36adc8baf3fcc6adc6723bf1955a035ecb7c42..84f7f1fc8eb096d123535e1fe2afdf79e6c8eab1 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} +import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.io.ChunkedByteBuffer @@ -28,7 +29,8 @@ class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} -class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext { +class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext + with EncryptionFunSuite { val clusterUrl = "local-cluster[2,1,1024]" @@ -149,8 +151,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex sc.parallelize(1 to 10).count() } - private def testCaching(storageLevel: StorageLevel): Unit = { - sc = new SparkContext(clusterUrl, "test") + private def testCaching(conf: SparkConf, storageLevel: StorageLevel): Unit = { + sc = new SparkContext(conf.setMaster(clusterUrl).setAppName("test")) sc.jobProgressListener.waitUntilExecutorsUp(2, 30000) val data = sc.parallelize(1 to 1000, 10) val cachedData = data.persist(storageLevel) @@ -187,8 +189,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2, "caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2 ).foreach { case (testName, storageLevel) => - test(testName) { - testCaching(storageLevel) + encryptionTest(testName) { conf => + testCaching(conf, storageLevel) } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 6646068d5080ba89d7b27451ec296c5c8eca7fa6..82760fe92f76ade28caf0e2a65ac18a6077dec18 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -24,8 +24,10 @@ import org.scalatest.Assertions import org.apache.spark._ import org.apache.spark.io.SnappyCompressionCodec import org.apache.spark.rdd.RDD +import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage._ +import org.apache.spark.util.io.ChunkedByteBuffer // Dummy class that creates a broadcast variable but doesn't use it class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { @@ -43,7 +45,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { } } -class BroadcastSuite extends SparkFunSuite with LocalSparkContext { +class BroadcastSuite extends SparkFunSuite with LocalSparkContext with EncryptionFunSuite { test("Using TorrentBroadcast locally") { sc = new SparkContext("local", "test") @@ -61,9 +63,8 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } - test("Accessing TorrentBroadcast variables in a local 
cluster") { + encryptionTest("Accessing TorrentBroadcast variables in a local cluster") { conf => val numSlaves = 4 - val conf = new SparkConf conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) @@ -85,7 +86,9 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val size = 1 + rand.nextInt(1024 * 10) val data: Array[Byte] = new Array[Byte](size) rand.nextBytes(data) - val blocks = blockifyObject(data, blockSize, serializer, compressionCodec) + val blocks = blockifyObject(data, blockSize, serializer, compressionCodec).map { b => + new ChunkedByteBuffer(b).toInputStream(dispose = true) + } val unblockified = unBlockifyObject[Array[Byte]](blocks, serializer, compressionCodec) assert(unblockified === data) } @@ -137,9 +140,8 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } - test("Cache broadcast to disk") { - val conf = new SparkConf() - .setMaster("local") + encryptionTest("Cache broadcast to disk") { conf => + conf.setMaster("local") .setAppName("test") .set("spark.memory.useLegacyMode", "true") .set("spark.storage.memoryFraction", "0.0") diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index 0f3a4a03618eddbb10b3ce8a9cf1a3a490dab7e7..608052f5ed855609c8968604600448b4c0cd3701 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -16,9 +16,11 @@ */ package org.apache.spark.security -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} +import java.nio.channels.Channels import java.nio.charset.StandardCharsets.UTF_8 -import java.util.UUID +import java.nio.file.Files +import java.util.{Arrays, Random, UUID} import com.google.common.io.ByteStreams @@ -121,6 +123,46 @@ class CryptoStreamUtilsSuite extends SparkFunSuite { } } + test("crypto stream wrappers") { + val testData = new Array[Byte](128 * 1024) + new Random().nextBytes(testData) + + val conf = createConf() + val key = createKey(conf) + val file = Files.createTempFile("crypto", ".test").toFile() + + val outStream = createCryptoOutputStream(new FileOutputStream(file), conf, key) + try { + ByteStreams.copy(new ByteArrayInputStream(testData), outStream) + } finally { + outStream.close() + } + + val inStream = createCryptoInputStream(new FileInputStream(file), conf, key) + try { + val inStreamData = ByteStreams.toByteArray(inStream) + assert(Arrays.equals(inStreamData, testData)) + } finally { + inStream.close() + } + + val outChannel = createWritableChannel(new FileOutputStream(file).getChannel(), conf, key) + try { + val inByteChannel = Channels.newChannel(new ByteArrayInputStream(testData)) + ByteStreams.copy(inByteChannel, outChannel) + } finally { + outChannel.close() + } + + val inChannel = createReadableChannel(new FileInputStream(file).getChannel(), conf, key) + try { + val inChannelData = ByteStreams.toByteArray(Channels.newInputStream(inChannel)) + assert(Arrays.equals(inChannelData, testData)) + } finally { + inChannel.close() + } + } + private def createConf(extra: (String, String)*): SparkConf = { val conf = new SparkConf() extra.foreach { case (k, v) => conf.set(k, v) } diff --git 
a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..3f52dc41abf6dadfe7d40b4c661c50795754523b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.security + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ + +trait EncryptionFunSuite { + + this: SparkFunSuite => + + /** + * Runs a test twice, initializing a SparkConf object with encryption off, then on. It's ok + * for the test to modify the provided SparkConf. + */ + final protected def encryptionTest(name: String)(fn: SparkConf => Unit) { + Seq(false, true).foreach { encrypt => + test(s"$name (encryption = ${ if (encrypt) "on" else "off" })") { + val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt) + fn(conf) + } + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 64a67b4c4cbab542d3ef9214de6dc639db032b8b..a8b960489983823e05315df866a37517caedab58 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.concurrent.Timeouts._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod +import org.apache.spark.internal.config._ import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.{BlockDataManager, BlockTransferService} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} @@ -42,6 +43,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -49,7 +51,8 @@ import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach - with PrivateMethodTester with LocalSparkContext with ResetSystemProperties { + with PrivateMethodTester with LocalSparkContext with ResetSystemProperties + with EncryptionFunSuite { import BlockManagerSuite._ @@ -75,16 
+78,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER, master: BlockManagerMaster = this.master, - transferService: Option[BlockTransferService] = Option.empty): BlockManager = { - conf.set("spark.testing.memory", maxMem.toString) - conf.set("spark.memory.offHeap.size", maxMem.toString) - val serializer = new KryoSerializer(conf) + transferService: Option[BlockTransferService] = Option.empty, + testConf: Option[SparkConf] = None): BlockManager = { + val bmConf = testConf.map(_.setAll(conf.getAll)).getOrElse(conf) + bmConf.set("spark.testing.memory", maxMem.toString) + bmConf.set("spark.memory.offHeap.size", maxMem.toString) + val serializer = new KryoSerializer(bmConf) + val encryptionKey = if (bmConf.get(IO_ENCRYPTION_ENABLED)) { + Some(CryptoStreamUtils.createKey(bmConf)) + } else { + None + } + val bmSecurityMgr = new SecurityManager(bmConf, encryptionKey) val transfer = transferService .getOrElse(new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1)) - val memManager = UnifiedMemoryManager(conf, numCores = 1) - val serializerManager = new SerializerManager(serializer, conf) - val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, conf, - memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + val memManager = UnifiedMemoryManager(bmConf, numCores = 1) + val serializerManager = new SerializerManager(serializer, bmConf) + val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, bmConf, + memManager, mapOutputTracker, shuffleManager, transfer, bmSecurityMgr, 0) memManager.setMemoryStore(blockManager.memoryStore) blockManager.initialize("app-id") blockManager @@ -610,8 +621,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store") } - test("on-disk storage") { - store = makeBlockManager(1200) + encryptionTest("on-disk storage") { _conf => + store = makeBlockManager(1200, testConf = Some(_conf)) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -623,34 +634,35 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was in store") } - test("disk and memory storage") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false) + encryptionTest("disk and memory storage") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false, testConf = conf) } - test("disk and memory storage with getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true) + encryptionTest("disk and memory storage with getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true, testConf = conf) } - test("disk and memory storage with serialization") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false) + encryptionTest("disk and memory storage with serialization") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false, testConf = conf) } - test("disk and memory storage with serialization and getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true) + encryptionTest("disk and memory storage with serialization and getLocalBytes") { _conf => + 
testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true, testConf = _conf) } - test("disk and off-heap memory storage") { - testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false) + encryptionTest("disk and off-heap memory storage") { _conf => + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false, testConf = _conf) } - test("disk and off-heap memory storage with getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true) + encryptionTest("disk and off-heap memory storage with getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true, testConf = _conf) } def testDiskAndMemoryStorage( storageLevel: StorageLevel, - getAsBytes: Boolean): Unit = { - store = makeBlockManager(12000) + getAsBytes: Boolean, + testConf: SparkConf): Unit = { + store = makeBlockManager(12000, testConf = Some(testConf)) val accessMethod = if (getAsBytes) store.getLocalBytesAndReleaseLock else store.getSingleAndReleaseLock val a1 = new Array[Byte](4000) @@ -678,8 +690,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } } - test("LRU with mixed storage levels") { - store = makeBlockManager(12000) + encryptionTest("LRU with mixed storage levels") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) @@ -700,8 +712,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getSingleAndReleaseLock("a4").isDefined, "a4 was not in store") } - test("in-memory LRU with streams") { - store = makeBlockManager(12000) + encryptionTest("in-memory LRU with streams") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) @@ -728,8 +740,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getAndReleaseLock("list3") === None, "list3 was in store") } - test("LRU with mixed storage levels and streams") { - store = makeBlockManager(12000) + encryptionTest("LRU with mixed storage levels and streams") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) @@ -1325,7 +1337,8 @@ private object BlockManagerSuite { val getAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.get) val getSingleAndReleaseLock: (BlockId) => Option[Any] = wrapGet(store.getSingle) val getLocalBytesAndReleaseLock: (BlockId) => Option[ChunkedByteBuffer] = { - wrapGet(store.getLocalBytes) + val allocator = ByteBuffer.allocate _ + wrapGet { bid => store.getLocalBytes(bid).map(_.toChunkedByteBuffer(allocator)) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 9e6b02b9eac4d8ea474b83f7f8021f97e0bbed93..67fc084e8a13d3da548ba0b73bc8339840e3aa6e 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -18,15 +18,23 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} -import java.util.Arrays +import
java.util.{Arrays, Random} -import org.apache.spark.{SparkConf, SparkFunSuite} +import com.google.common.io.{ByteStreams, Files} +import io.netty.channel.FileRegion + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.{ByteArrayWritableChannel, JavaUtils} +import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.io.ChunkedByteBuffer import org.apache.spark.util.Utils class DiskStoreSuite extends SparkFunSuite { test("reads of memory-mapped and non memory-mapped files are equivalent") { + val conf = new SparkConf() + val securityManager = new SecurityManager(conf) + // It will cause an error when we try to re-open the file store and the // memory-mapped byte buffer to the file has not been GC'd on Windows. assume(!Utils.isWindows) @@ -37,16 +45,18 @@ class DiskStoreSuite extends SparkFunSuite { val byteBuffer = new ChunkedByteBuffer(ByteBuffer.wrap(bytes)) val blockId = BlockId("rdd_1_2") - val diskBlockManager = new DiskBlockManager(new SparkConf(), deleteFilesOnStop = true) + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) - val diskStoreMapped = new DiskStore(new SparkConf().set(confKey, "0"), diskBlockManager) + val diskStoreMapped = new DiskStore(conf.clone().set(confKey, "0"), diskBlockManager, + securityManager) diskStoreMapped.putBytes(blockId, byteBuffer) - val mapped = diskStoreMapped.getBytes(blockId) + val mapped = diskStoreMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer assert(diskStoreMapped.remove(blockId)) - val diskStoreNotMapped = new DiskStore(new SparkConf().set(confKey, "1m"), diskBlockManager) + val diskStoreNotMapped = new DiskStore(conf.clone().set(confKey, "1m"), diskBlockManager, + securityManager) diskStoreNotMapped.putBytes(blockId, byteBuffer) - val notMapped = diskStoreNotMapped.getBytes(blockId) + val notMapped = diskStoreNotMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer // Not possible to do isInstanceOf due to visibility of HeapByteBuffer assert(notMapped.getChunks().forall(_.getClass.getName.endsWith("HeapByteBuffer")), @@ -63,4 +73,95 @@ class DiskStoreSuite extends SparkFunSuite { assert(Arrays.equals(mapped.toArray, bytes)) assert(Arrays.equals(notMapped.toArray, bytes)) } + + test("block size tracking") { + val conf = new SparkConf() + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) + val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) + + val blockId = BlockId("rdd_1_2") + diskStore.put(blockId) { chan => + val buf = ByteBuffer.wrap(new Array[Byte](32)) + while (buf.hasRemaining()) { + chan.write(buf) + } + } + + assert(diskStore.getSize(blockId) === 32L) + diskStore.remove(blockId) + assert(diskStore.getSize(blockId) === 0L) + } + + test("block data encryption") { + val testDir = Utils.createTempDir() + val testData = new Array[Byte](128 * 1024) + new Random().nextBytes(testData) + + val conf = new SparkConf() + val securityManager = new SecurityManager(conf, Some(CryptoStreamUtils.createKey(conf))) + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) + val diskStore = new DiskStore(conf, diskBlockManager, securityManager) + + val blockId = BlockId("rdd_1_2") + diskStore.put(blockId) { chan => + val buf = ByteBuffer.wrap(testData) + while (buf.hasRemaining()) { + chan.write(buf) + } + } + + assert(diskStore.getSize(blockId) === testData.length) + + val diskData = Files.toByteArray(diskBlockManager.getFile(blockId.name)) +
assert(!Arrays.equals(testData, diskData)) + + val blockData = diskStore.getBytes(blockId) + assert(blockData.isInstanceOf[EncryptedBlockData]) + assert(blockData.size === testData.length) + Map( + "input stream" -> readViaInputStream _, + "chunked byte buffer" -> readViaChunkedByteBuffer _, + "nio byte buffer" -> readViaNioBuffer _, + "managed buffer" -> readViaManagedBuffer _ + ).foreach { case (name, fn) => + val readData = fn(blockData) + assert(readData.length === blockData.size, s"Size of data read via $name did not match.") + assert(Arrays.equals(testData, readData), s"Data read via $name did not match.") + } + } + + private def readViaInputStream(data: BlockData): Array[Byte] = { + val is = data.toInputStream() + try { + ByteStreams.toByteArray(is) + } finally { + is.close() + } + } + + private def readViaChunkedByteBuffer(data: BlockData): Array[Byte] = { + val buf = data.toChunkedByteBuffer(ByteBuffer.allocate _) + try { + buf.toArray + } finally { + buf.dispose() + } + } + + private def readViaNioBuffer(data: BlockData): Array[Byte] = { + JavaUtils.bufferToArray(data.toByteBuffer()) + } + + private def readViaManagedBuffer(data: BlockData): Array[Byte] = { + val region = data.toNetty().asInstanceOf[FileRegion] + val byteChannel = new ByteArrayWritableChannel(data.size.toInt) + + while (region.transfered() < region.count()) { + region.transferTo(byteChannel, region.transfered()) + } + + byteChannel.close() + byteChannel.getData + } + } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index d0864fd3678b287e8c059bcbbd9b0c13aa47956f..844760ab61d2e839d6447d24af391713957b7fc1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -158,16 +158,14 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logInfo(s"Read partition data of $this from write ahead log, record handle " + partition.walRecordHandle) if (storeInBlockManager) { - blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel, - encrypt = true) + blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel) logDebug(s"Stored partition data of $this into block manager with level $storageLevel") dataRead.rewind() } serializerManager .dataDeserializeStream( blockId, - new ChunkedByteBuffer(dataRead).toInputStream(), - maybeEncrypted = false)(elementClassTag) + new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag) .asInstanceOf[Iterator[T]] } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 2b488038f06206fb997df67e815797969faaffc0..80c07958b41f24ab7fa2e17bfe81994bd71e1cb5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -87,8 +87,7 @@ private[streaming] class BlockManagerBasedBlockHandler( putResult case ByteBufferBlock(byteBuffer) => blockManager.putBytes( - blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true, - encrypt = true) + blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true) case o => throw new 
SparkException( s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}") @@ -176,11 +175,10 @@ private[streaming] class WriteAheadLogBasedBlockHandler( val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => numRecords = Some(arrayBuffer.size.toLong) - serializerManager.dataSerialize(blockId, arrayBuffer.iterator, allowEncryption = false) + serializerManager.dataSerialize(blockId, arrayBuffer.iterator) case IteratorBlock(iterator) => val countIterator = new CountingIterator(iterator) - val serializedBlock = serializerManager.dataSerialize(blockId, countIterator, - allowEncryption = false) + val serializedBlock = serializerManager.dataSerialize(blockId, countIterator) numRecords = countIterator.count serializedBlock case ByteBufferBlock(byteBuffer) => @@ -195,8 +193,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( blockId, serializedBlock, effectiveStorageLevel, - tellMaster = true, - encrypt = true) + tellMaster = true) if (!putSucceeded) { throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index c2b0389b8c6f0bdc79f0f356681c7f3c721c9b77..3c4a2716caf9012e5e96b01028164ba1f14a9e6e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -175,8 +175,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) reader.close() serializerManager.dataDeserializeStream( generateBlockId(), - new ChunkedByteBuffer(bytes).toInputStream(), - maybeEncrypted = false)(ClassTag.Any).toList + new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList } loggedData shouldEqual data } @@ -357,7 +356,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) } def dataToByteBuffer(b: Seq[String]) = - serializerManager.dataSerialize(generateBlockId, b.iterator, allowEncryption = false) + serializerManager.dataSerialize(generateBlockId, b.iterator) val blocks = data.grouped(10).toSeq diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 2ac0dc96916c551109bbe61e3a2612bb13e05aba..aa69be7ca993912effc4e4d166029e0c089e6db6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -250,8 +250,7 @@ class WriteAheadLogBackedBlockRDDSuite require(blockData.size === blockIds.size) val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) val segments = blockData.zip(blockIds).map { case (data, id) => - writer.write(serializerManager.dataSerialize(id, data.iterator, allowEncryption = false) - .toByteBuffer) + writer.write(serializerManager.dataSerialize(id, data.iterator).toByteBuffer) } writer.close() segments
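
For readers unfamiliar with the new helper, here is a minimal sketch of how a suite is expected to mix in EncryptionFunSuite. The suite name and assertion are hypothetical; the sketch only assumes the trait added in this patch, CryptoStreamUtils.createKey, the SecurityManager(conf, key) constructor used elsewhere in the patch, and SecurityManager.getIOEncryptionKey().

package org.apache.spark.security

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.internal.config._

// Hypothetical suite; it exists only to illustrate the encryptionTest helper.
class ExampleEncryptionSuite extends SparkFunSuite with EncryptionFunSuite {

  // Generates two tests: "... (encryption = off)" and "... (encryption = on)".
  encryptionTest("security manager only gets a key when encryption is enabled") { conf =>
    // Mirror makeBlockManager() in BlockManagerSuite: create an I/O encryption key
    // only when the conf handed in by encryptionTest() has encryption turned on.
    val key = if (conf.get(IO_ENCRYPTION_ENABLED)) {
      Some(CryptoStreamUtils.createKey(conf))
    } else {
      None
    }
    val securityManager = new SecurityManager(conf, key)
    assert(securityManager.getIOEncryptionKey().isDefined === conf.get(IO_ENCRYPTION_ENABLED))
  }
}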
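
The "block size tracking" and "block data encryption" tests above exercise the write-callback form of DiskStore.put. The following sketch restates that pattern in isolation; the suite name and block id are made up, while DiskStore, DiskBlockManager, BlockId, SecurityManager, put, getSize and remove all appear in the tests shown in this patch.

package org.apache.spark.storage

import java.nio.ByteBuffer

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}

// Hypothetical suite, placed in the storage test package so it can see the
// private[spark] DiskStore and DiskBlockManager classes.
class DiskStorePutSketchSuite extends SparkFunSuite {

  test("put() writes through the supplied channel and getSize() tracks the block") {
    val conf = new SparkConf()
    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
    val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf))

    val blockId = BlockId("rdd_7_7")
    diskStore.put(blockId) { chan =>
      // The caller is handed a WritableByteChannel; when I/O encryption is on
      // (it is off here), bytes are encrypted before they reach the file.
      val buf = ByteBuffer.wrap(new Array[Byte](1024))
      while (buf.hasRemaining()) {
        chan.write(buf)
      }
    }

    // DiskStore reports the logical block size, and forgets it once removed.
    assert(diskStore.getSize(blockId) === 1024L)
    assert(diskStore.remove(blockId))
    assert(diskStore.getSize(blockId) === 0L)
  }
}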
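
Since dataSerialize and dataDeserializeStream lose their encryption flags in the streaming changes above, WAL data is written and read back as plain serialized bytes, with encryption applied only when a block reaches the disk store. A hedged round-trip sketch follows; the suite name and sample records are hypothetical, while the two SerializerManager calls and ChunkedByteBuffer.toInputStream are taken from the patch.

package org.apache.spark.streaming

import scala.reflect.ClassTag

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
import org.apache.spark.storage.StreamBlockId

// Hypothetical suite; it only restates the serialize/deserialize calls used by
// the WAL code paths, now without the allowEncryption / maybeEncrypted flags.
class WALSerializationSketchSuite extends SparkFunSuite {

  test("dataSerialize / dataDeserializeStream round trip") {
    val conf = new SparkConf()
    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
    val blockId = StreamBlockId(0, 0L)
    val records = Seq("a", "b", "c")

    // Serialize the records, then deserialize them back through the same manager.
    val serialized = serializerManager.dataSerialize(blockId, records.iterator)
    val roundTripped = serializerManager
      .dataDeserializeStream(blockId, serialized.toInputStream())(ClassTag(classOf[String]))
      .toList

    assert(roundTripped === records)
  }
}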