diff --git a/core/pom.xml b/core/pom.xml index c04cf7e5255f250ee560bf45b49a5416450a96c1..69a0b0ff27c39ac2590da3cf258a2235ce652484 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -327,6 +327,10 @@ <groupId>org.apache.spark</groupId> <artifactId>spark-tags_${scala.binary.version}</artifactId> </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-crypto</artifactId> + </dependency> </dependencies> <build> <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index d048cf7aeb5f18aad0a226d9ee06b2d64b78dfd3..2875b0d69def6288c13fbaeb44e481dc07f64470 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -72,7 +72,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file), (int) bufferSizeBytes); try { - this.in = serializerManager.wrapForCompression(blockId, bs); + this.in = serializerManager.wrapStream(blockId, bs); this.din = new DataInputStream(this.in); numRecords = numRecordsRemaining = din.readInt(); } catch (IOException e) { diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index a6550b6ca8c942fafe3ed12be50fcb17ae5e34ef..199365ad925a3b119ee94e46a345cb589228a2ae 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -21,15 +21,19 @@ import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} import java.security.{KeyStore, SecureRandom} import java.security.cert.X509Certificate +import javax.crypto.KeyGenerator import javax.net.ssl._ import com.google.common.hash.HashCodes import com.google.common.io.Files import org.apache.hadoop.io.Text +import org.apache.hadoop.security.Credentials import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.network.sasl.SecretKeyHolder +import org.apache.spark.security.CryptoStreamUtils._ import org.apache.spark.util.Utils /** @@ -554,4 +558,20 @@ private[spark] object SecurityManager { // key used to store the spark secret in the Hadoop UGI val SECRET_LOOKUP_KEY = "sparkCookie" + + /** + * Setup the cryptographic key used by IO encryption in credentials. The key is generated using + * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]]. 
+ */ + def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = { + if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) { + val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS) + val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM) + val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm) + keyGen.init(keyLen) + + val ioKey = keyGen.generateKey() + credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded) + } + } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 08d6343d623cf5bd3916c6f266b1130d1c7aaeb3..744d5d0f7aa8e7acb1be18accede84cea0445fe0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -49,6 +49,7 @@ import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ @@ -411,6 +412,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") + if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) { + throw new SparkException("IO encryption is only supported in YARN mode, please disable it " + + s"by setting ${IO_ENCRYPTION_ENABLED.key} to false") + } // "_jobProgressListener" should be set up before creating SparkEnv because when creating // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 47174e4efee81757a28f879410d7764224607bdb..ebce07c1e3b3eefd995b91990c1de86785695120 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -119,4 +119,24 @@ package object config { private[spark] val UI_RETAINED_TASKS = ConfigBuilder("spark.ui.retainedTasks") .intConf .createWithDefault(100000) + + private[spark] val IO_ENCRYPTION_ENABLED = ConfigBuilder("spark.io.encryption.enabled") + .booleanConf + .createWithDefault(false) + + private[spark] val IO_ENCRYPTION_KEYGEN_ALGORITHM = + ConfigBuilder("spark.io.encryption.keygen.algorithm") + .stringConf + .createWithDefault("HmacSHA1") + + private[spark] val IO_ENCRYPTION_KEY_SIZE_BITS = ConfigBuilder("spark.io.encryption.keySizeBits") + .intConf + .checkValues(Set(128, 192, 256)) + .createWithDefault(128) + + private[spark] val IO_CRYPTO_CIPHER_TRANSFORMATION = + ConfigBuilder("spark.io.crypto.cipher.transformation") + .internal() + .stringConf + .createWithDefaultString("AES/CTR/NoPadding") } diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..8f15f50bee8146b5ec7ef3b9c3a5edb34bf88bd5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.security
+
+import java.io.{InputStream, OutputStream}
+import java.util.Properties
+import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
+
+import org.apache.commons.crypto.random._
+import org.apache.commons.crypto.stream._
+import org.apache.hadoop.io.Text
+
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+
+/**
+ * A util class for manipulating IO encryption and decryption streams.
+ */
+private[spark] object CryptoStreamUtils extends Logging {
+  /**
+   * Constants and variables for Spark IO encryption.
+   */
+  val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN")
+
+  // The initialization vector length in bytes.
+  val IV_LENGTH_IN_BYTES = 16
+  // The prefix of IO encryption related configurations in Spark configuration.
+  val SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX = "spark.io.encryption.commons.config."
+  // The prefix for the configurations passed to the Apache Commons Crypto library.
+  val COMMONS_CRYPTO_CONF_PREFIX = "commons.crypto."
+
+  /**
+   * Helper method to wrap [[OutputStream]] with [[CryptoOutputStream]] for encryption.
+   */
+  def createCryptoOutputStream(
+      os: OutputStream,
+      sparkConf: SparkConf): OutputStream = {
+    val properties = toCryptoConf(sparkConf)
+    val iv = createInitializationVector(properties)
+    os.write(iv)
+    val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
+    val key = credentials.getSecretKey(SPARK_IO_TOKEN)
+    val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
+    new CryptoOutputStream(transformationStr, properties, os,
+      new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
+  }
+
+  /**
+   * Helper method to wrap [[InputStream]] with [[CryptoInputStream]] for decryption.
+   */
+  def createCryptoInputStream(
+      is: InputStream,
+      sparkConf: SparkConf): InputStream = {
+    val properties = toCryptoConf(sparkConf)
+    val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
+    com.google.common.io.ByteStreams.readFully(is, iv)  // read() may return fewer bytes
+    val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
+    val key = credentials.getSecretKey(SPARK_IO_TOKEN)
+    val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
+    new CryptoInputStream(transformationStr, properties, is,
+      new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
+  }
+
+  /**
+   * Get Commons Crypto configurations from Spark configurations identified by prefix.
+   */
+  def toCryptoConf(conf: SparkConf): Properties = {
+    val props = new Properties()
+    conf.getAll.foreach { case (k, v) =>
+      if (k.startsWith(SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX)) {
+        props.put(COMMONS_CRYPTO_CONF_PREFIX + k.substring(
+          SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.length()), v)
+      }
+    }
+    props
+  }
+
+  /**
+   * Generate an IV (Initialization Vector) using a secure random number generator.
+   */
+  private[this] def createInitializationVector(properties: Properties): Array[Byte] = {
+    val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
+    val initialIVStart = System.currentTimeMillis()
+    CryptoRandomFactory.getCryptoRandom(properties).nextBytes(iv)
+    val initialIVFinish = System.currentTimeMillis()
+    val initialIVTime = initialIVFinish - initialIVStart
+    if (initialIVTime > 2000) {
+      logWarning(s"It took ${initialIVTime} milliseconds to create the Initialization Vector " +
+        s"used by CryptoStream")
+    }
+    iv
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 07caadbe40438f7f5848fdc18c1ae01b73dbcae1..7b1ec6fcbbbf61dcbcd54e6eb135c5c53bbf2af5 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -23,13 +23,15 @@ import java.nio.ByteBuffer
 import scala.reflect.ClassTag
 
 import org.apache.spark.SparkConf
+import org.apache.spark.internal.config._
 import org.apache.spark.io.CompressionCodec
+import org.apache.spark.security.CryptoStreamUtils
 import org.apache.spark.storage._
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
 
 /**
- * Component which configures serialization and compression for various Spark components, including
- * automatic selection of which [[Serializer]] to use for shuffles.
+ * Component which configures serialization, compression and encryption for various Spark
+ * components, including automatic selection of which [[Serializer]] to use for shuffles.
  */
 private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {
 
@@ -61,6 +63,9 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
   // Whether to compress shuffle output temporarily spilled to disk
   private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
 
+  // Whether to enable IO encryption
+  private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED)
+
   /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
    * the initialization of the compression codec until it is first used.
The reason is that a Spark * program could be using a user-defined codec in a third party jar, which is loaded in @@ -102,17 +107,45 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar } } + /** + * Wrap an input stream for encryption and compression + */ + def wrapStream(blockId: BlockId, s: InputStream): InputStream = { + wrapForCompression(blockId, wrapForEncryption(s)) + } + + /** + * Wrap an output stream for encryption and compression + */ + def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = { + wrapForCompression(blockId, wrapForEncryption(s)) + } + + /** + * Wrap an input stream for encryption if shuffle encryption is enabled + */ + private[this] def wrapForEncryption(s: InputStream): InputStream = { + if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s + } + + /** + * Wrap an output stream for encryption if shuffle encryption is enabled + */ + private[this] def wrapForEncryption(s: OutputStream): OutputStream = { + if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s + } + /** * Wrap an output stream for compression if block compression is enabled for its block type */ - def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } /** * Wrap an input stream for compression if block compression is enabled for its block type */ - def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } @@ -123,7 +156,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar values: Iterator[T]): Unit = { val byteStream = new BufferedOutputStream(outputStream) val ser = getSerializer(implicitly[ClassTag[T]]).newInstance() - ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() } /** Serializes into a chunked byte buffer. 
*/ @@ -139,7 +172,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) val byteStream = new BufferedOutputStream(bbos) val ser = getSerializer(classTag).newInstance() - ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() bbos.toChunkedByteBuffer } @@ -153,7 +186,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar val stream = new BufferedInputStream(inputStream) getSerializer(implicitly[ClassTag[T]]) .newInstance() - .deserializeStream(wrapForCompression(blockId, stream)) + .deserializeStream(wrapStream(blockId, stream)) .asIterator.asInstanceOf[Iterator[T]] } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 5794f542b7564bf80c4ba9e5c047fb11b6d2954e..b9d83495d29b63eaa7af3f135f41bf2c21a21d78 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -51,9 +51,9 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)) - // Wrap the streams for compression based on configuration + // Wrap the streams for compression and encryption based on configuration val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => - serializerManager.wrapForCompression(blockId, inputStream) + serializerManager.wrapStream(blockId, inputStream) } val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index fe8465279860dcba6a33654e136ecf23cb401973..c72f28e00cdbc5bc8dfb3d5b252137186caa801d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -721,10 +721,9 @@ private[spark] class BlockManager( serializerInstance: SerializerInstance, bufferSize: Int, writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { - val compressStream: OutputStream => OutputStream = - serializerManager.wrapForCompression(blockId, _) + val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream, + new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream, syncWrites, writeMetrics, blockId) } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index e5b1bf2f4b43461d3ff488bc42d21e7d92fd37aa..a499827ae159890c32659668650d15f203878899 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -39,7 +39,7 @@ private[spark] class DiskBlockObjectWriter( val file: File, serializerInstance: SerializerInstance, bufferSize: Int, - compressStream: OutputStream => OutputStream, + wrapStream: OutputStream => OutputStream, syncWrites: Boolean, // These write 
metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. @@ -115,7 +115,8 @@ private[spark] class DiskBlockObjectWriter( initialize() initialized = true } - bs = compressStream(mcs) + + bs = wrapStream(mcs) objOut = serializerInstance.serializeStream(bs) streamOpen = true this diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 586339a58d236b2bcb5e23d6a83202c560f9b618..d220ab51d115bfcbc24f2deb03bff153a75630a9 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -330,7 +330,7 @@ private[spark] class MemoryStore( redirectableStream.setOutputStream(bbos) val serializationStream: SerializationStream = { val ser = serializerManager.getSerializer(classTag).newInstance() - ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) + ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream)) } // Request enough memory to begin unrolling diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 8c8860bb37a4042e2e8014092b261b34672c5396..09435281194b58c94d2064e0ad4a69848a9e49c1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -486,8 +486,8 @@ class ExternalAppendOnlyMap[K, V, C]( ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val compressedStream = serializerManager.wrapForCompression(blockId, bufferedStream) - ser.deserializeStream(compressedStream) + val wrappedStream = serializerManager.wrapStream(blockId, bufferedStream) + ser.deserializeStream(wrappedStream) } else { // No more batches left cleanup() diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 7c98e8cabb22918ca54c9bfc265bfb4b788463ca..3579918fac45ff3d71209c13fbd163bf97a217b4 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -28,7 +28,6 @@ import com.google.common.io.ByteStreams import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging -import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer._ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} @@ -522,8 +521,9 @@ private[spark] class ExternalSorter[K, V, C]( ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val compressedStream = serializerManager.wrapForCompression(spill.blockId, bufferedStream) - serInstance.deserializeStream(compressedStream) + + val wrappedStream = serializerManager.wrapStream(spill.blockId, bufferedStream) + serInstance.deserializeStream(wrappedStream) } else { // No more batches left cleanup() diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java 
b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index daeb4675ea5f53ed35286ac0ba26e38f46bb18fb..a96cd82382e2cd06d3ed0018d07d1fee9ca88536 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -86,7 +86,7 @@ public class UnsafeShuffleWriterSuite { @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep; - private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> { + private final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> { @Override public OutputStream apply(OutputStream stream) { if (conf.getBoolean("spark.shuffle.compress", true)) { @@ -136,7 +136,7 @@ public class UnsafeShuffleWriterSuite { (File) args[1], (SerializerInstance) args[2], (Integer) args[3], - new CompressStream(), + new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index fc127f07c8d690a418fb4aebcb5e5e807e7539b1..33709b454c4c98044895c213179030a6d7c781d1 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -75,7 +75,7 @@ public abstract class AbstractBytesToBytesMapSuite { @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; - private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> { + private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> { @Override public OutputStream apply(OutputStream stream) { return stream; @@ -122,7 +122,7 @@ public abstract class AbstractBytesToBytesMapSuite { (File) args[1], (SerializerInstance) args[2], (Integer) args[3], - new CompressStream(), + new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 3ea99233fe17176e13188214fdb8323968f11728..a9cf8ff520ed432bfc5033d37b0ce15aa13d1071 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -88,7 +88,7 @@ public class UnsafeExternalSorterSuite { private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m"); - private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> { + private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> { @Override public OutputStream apply(OutputStream stream) { return stream; @@ -128,7 +128,7 @@ public class UnsafeExternalSorterSuite { (File) args[1], (SerializerInstance) args[2], (Integer) args[3], - new CompressStream(), + new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala 
b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..81eb907ac7ba608b95a51b1c75549560ccdad164 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.security + +import java.security.PrivilegedExceptionAction + +import org.apache.hadoop.security.{Credentials, UserGroupInformation} + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ +import org.apache.spark.security.CryptoStreamUtils._ + +class CryptoStreamUtilsSuite extends SparkFunSuite { + val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup")) + + test("Crypto configuration conversion") { + val sparkKey1 = s"${SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX}a.b.c" + val sparkVal1 = "val1" + val cryptoKey1 = s"${COMMONS_CRYPTO_CONF_PREFIX}a.b.c" + + val sparkKey2 = SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.stripSuffix(".") + "A.b.c" + val sparkVal2 = "val2" + val cryptoKey2 = s"${COMMONS_CRYPTO_CONF_PREFIX}A.b.c" + val conf = new SparkConf() + conf.set(sparkKey1, sparkVal1) + conf.set(sparkKey2, sparkVal2) + val props = CryptoStreamUtils.toCryptoConf(conf) + assert(props.getProperty(cryptoKey1) === sparkVal1) + assert(!props.containsKey(cryptoKey2)) + } + + test("Shuffle encryption is disabled by default") { + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + val credentials = UserGroupInformation.getCurrentUser.getCredentials() + val conf = new SparkConf() + initCredentials(conf, credentials) + assert(credentials.getSecretKey(SPARK_IO_TOKEN) === null) + } + }) + } + + test("Shuffle encryption key length should be 128 by default") { + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + val credentials = UserGroupInformation.getCurrentUser.getCredentials() + val conf = new SparkConf() + conf.set(IO_ENCRYPTION_ENABLED, true) + initCredentials(conf, credentials) + var key = credentials.getSecretKey(SPARK_IO_TOKEN) + assert(key !== null) + val actual = key.length * (java.lang.Byte.SIZE) + assert(actual === 128) + } + }) + } + + test("Initial credentials with key length in 256") { + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + val credentials = UserGroupInformation.getCurrentUser.getCredentials() + val conf = new SparkConf() + conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 256) + conf.set(IO_ENCRYPTION_ENABLED, true) + initCredentials(conf, credentials) + var key = credentials.getSecretKey(SPARK_IO_TOKEN) + assert(key !== null) + val actual = key.length * (java.lang.Byte.SIZE) + assert(actual === 256) + 
} + }) + } + + test("Initial credentials with invalid key length") { + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + val credentials = UserGroupInformation.getCurrentUser.getCredentials() + val conf = new SparkConf() + conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 328) + conf.set(IO_ENCRYPTION_ENABLED, true) + val thrown = intercept[IllegalArgumentException] { + initCredentials(conf, credentials) + } + } + }) + } + + private[this] def initCredentials(conf: SparkConf, credentials: Credentials): Unit = { + if (conf.get(IO_ENCRYPTION_ENABLED)) { + SecurityManager.initIOEncryptionKey(conf, credentials) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 5132384a5ed7d003534872c38a239058a5bfeb98..ed9428820ff6c4b1c348f6d24cb603dca77c5e0f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -94,7 +94,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte args(1).asInstanceOf[File], args(2).asInstanceOf[SerializerInstance], args(3).asInstanceOf[Int], - compressStream = identity, + wrapStream = identity, syncWrites = false, args(4).asInstanceOf[ShuffleWriteMetrics], blockId = args(0).asInstanceOf[BlockId] diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 326271a7e2b23a74bfd389f5007969cad4b27a76..eaed0889ac36fb3c0443d4715a395b4186f0d5dc 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -27,6 +27,7 @@ commons-collections-3.2.2.jar commons-compiler-2.7.6.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar +commons-crypto-1.0.0.jar commons-dbcp-1.4.jar commons-digester-1.8.jar commons-httpclient-3.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 1ff6ecb7342bb13905d2993c38a8f13d3ec95969..d68a7f462ba7f7d48372933f119b918f7fdb255a 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -30,6 +30,7 @@ commons-collections-3.2.2.jar commons-compiler-2.7.6.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar +commons-crypto-1.0.0.jar commons-dbcp-1.4.jar commons-digester-1.8.jar commons-httpclient-3.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 68333849cf4c9a41fb0b6076ab33cbd9c34c741b..346f19767d3670aef441fc8e496db05bdbf670a5 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -30,6 +30,7 @@ commons-collections-3.2.2.jar commons-compiler-2.7.6.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar +commons-crypto-1.0.0.jar commons-dbcp-1.4.jar commons-digester-1.8.jar commons-httpclient-3.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 787d06c3512dba53a68d3427cf6c4a7071928ea8..6f4695f345a486f57c84ca07ed1850f5f652c81b 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -34,6 +34,7 @@ commons-collections-3.2.2.jar commons-compiler-2.7.6.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar +commons-crypto-1.0.0.jar commons-dbcp-1.4.jar commons-digester-1.8.jar commons-httpclient-3.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 
386495bf1bbb1f6175b27fdf1d19c66c5c70ff14..7a86a8bd8884628c5bd135bd98cc43cbd847d78a 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -34,6 +34,7 @@ commons-collections-3.2.2.jar commons-compiler-2.7.6.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar +commons-crypto-1.0.0.jar commons-dbcp-1.4.jar commons-digester-1.8.jar commons-httpclient-3.1.jar diff --git a/docs/configuration.md b/docs/configuration.md index 2f801961050e1c9274f3a81aee6d7846dca2a97a..d0c76aaad0b35f29a14dbbdde119412803c2e74a 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -559,6 +559,29 @@ Apart from these, the following properties are also available, and may be useful <code>spark.io.compression.codec</code>. </td> </tr> +<tr> + <td><code>spark.io.encryption.enabled</code></td> + <td>false</td> + <td> + Enable IO encryption. Only supported in YARN mode. + </td> +</tr> +<tr> + <td><code>spark.io.encryption.keySizeBits</code></td> + <td>128</td> + <td> + IO encryption key size in bits. Supported values are 128, 192 and 256. + </td> +</tr> +<tr> + <td><code>spark.io.encryption.keygen.algorithm</code></td> + <td>HmacSHA1</td> + <td> + The algorithm to use when generating the IO encryption key. The supported algorithms are + described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm + Name Documentation. + </td> +</tr> </table> #### Spark UI diff --git a/pom.xml b/pom.xml index 74238db59ed8f78a96c526108b3b90d7af559e1b..2c265c1fa325e5f8580eba6d47396064a66c259f 100644 --- a/pom.xml +++ b/pom.xml @@ -180,6 +180,7 @@ <selenium.version>2.52.0</selenium.version> <paranamer.version>2.8</paranamer.version> <maven-antrun.version>1.8</maven-antrun.version> + <commons-crypto.version>1.0.0</commons-crypto.version> <test.java.home>${java.home}</test.java.home> <test.exclude.tags></test.exclude.tags> @@ -1825,6 +1826,17 @@ <artifactId>jline</artifactId> <version>${jline.version}</version> </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-crypto</artifactId> + <version>${commons-crypto.version}</version> + <exclusions> + <exclusion> + <groupId>net.java.dev.jna</groupId> + <artifactId>jna</artifactId> + </exclusion> + </exclusions> + </dependency> </dependencies> </dependencyManagement> diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7fbbe91de94e58f3269c0ff9ace9f74c94a231a7..2398f0aea316a420355fba34d0aca14da7e57f87 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1003,6 +1003,10 @@ private[spark] class Client( val securityManager = new SecurityManager(sparkConf) amContainer.setApplicationACLs( YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava) + + if (sparkConf.get(IO_ENCRYPTION_ENABLED)) { + SecurityManager.initIOEncryptionKey(sparkConf, credentials) + } setupSecurityToken(amContainer) amContainer diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..1c60315b21ae80d4b01c5a8ab577f1e8d0b999b3 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import java.io._ +import java.nio.charset.StandardCharsets +import java.security.PrivilegedExceptionAction +import java.util.UUID + +import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers} + +import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.config._ +import org.apache.spark.serializer._ +import org.apache.spark.storage._ + +class IOEncryptionSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll + with BeforeAndAfterEach { + private[this] val blockId = new TempShuffleBlockId(UUID.randomUUID()) + private[this] val conf = new SparkConf() + private[this] val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup")) + private[this] val serializer = new KryoSerializer(conf) + + override def beforeAll(): Unit = { + System.setProperty("SPARK_YARN_MODE", "true") + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + conf.set(IO_ENCRYPTION_ENABLED, true) + val creds = new Credentials() + SecurityManager.initIOEncryptionKey(conf, creds) + SparkHadoopUtil.get.addCurrentUserCredentials(creds) + } + }) + } + + override def afterAll(): Unit = { + SparkEnv.set(null) + System.clearProperty("SPARK_YARN_MODE") + } + + override def beforeEach(): Unit = { + super.beforeEach() + } + + override def afterEach(): Unit = { + super.afterEach() + conf.set("spark.shuffle.compress", false.toString) + conf.set("spark.shuffle.spill.compress", false.toString) + } + + test("IO encryption read and write") { + ugi.doAs(new PrivilegedExceptionAction[Unit] { + override def run(): Unit = { + conf.set(IO_ENCRYPTION_ENABLED, true) + conf.set("spark.shuffle.compress", false.toString) + conf.set("spark.shuffle.spill.compress", false.toString) + testYarnIOEncryptionWriteRead() + } + }) + } + + test("IO encryption read and write with shuffle compression enabled") { + ugi.doAs(new PrivilegedExceptionAction[Unit] { + override def run(): Unit = { + conf.set(IO_ENCRYPTION_ENABLED, true) + conf.set("spark.shuffle.compress", true.toString) + conf.set("spark.shuffle.spill.compress", true.toString) + testYarnIOEncryptionWriteRead() + } + }) + } + + private[this] def testYarnIOEncryptionWriteRead(): Unit = { + val plainStr = "hello world" + val outputStream = new ByteArrayOutputStream() + val serializerManager = new SerializerManager(serializer, conf) + val wrappedOutputStream = serializerManager.wrapStream(blockId, outputStream) + wrappedOutputStream.write(plainStr.getBytes(StandardCharsets.UTF_8)) + wrappedOutputStream.close() + + val encryptedBytes = outputStream.toByteArray + val encryptedStr = new String(encryptedBytes) + assert(plainStr !== encryptedStr) + + val inputStream = new 
ByteArrayInputStream(encryptedBytes) + val wrappedInputStream = serializerManager.wrapStream(blockId, inputStream) + val decryptedBytes = new Array[Byte](1024) + val len = wrappedInputStream.read(decryptedBytes) + val decryptedStr = new String(decryptedBytes, 0, len, StandardCharsets.UTF_8) + assert(decryptedStr === plainStr) + } +}
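
A note on the new configuration passthrough: any `spark.io.encryption.commons.config.*` entry is re-keyed under `commons.crypto.*` and handed to Commons Crypto by `CryptoStreamUtils.toCryptoConf`. A minimal sketch of that mapping, runnable against this patch; the `cipher.classes` property and the `JceCipher` implementation are standard Commons Crypto settings used here purely as an illustration:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.security.CryptoStreamUtils

// "spark.io.encryption.commons.config.x.y" is forwarded as "commons.crypto.x.y";
// unrelated Spark settings are left out of the resulting Properties.
val conf = new SparkConf()
  .set("spark.io.encryption.enabled", "true")
  .set("spark.io.encryption.commons.config.cipher.classes",
    "org.apache.commons.crypto.cipher.JceCipher")

val props = CryptoStreamUtils.toCryptoConf(conf)
assert(props.getProperty("commons.crypto.cipher.classes") ==
  "org.apache.commons.crypto.cipher.JceCipher")
assert(!props.containsKey("spark.io.encryption.enabled"))
```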
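
The key stored under `SPARK_IO_TOKEN` is plain `javax.crypto` key material; the sketch below mirrors what `SecurityManager.initIOEncryptionKey` does with the default settings (`HmacSHA1` keygen, 128-bit key), without involving Hadoop `Credentials`:

```scala
import javax.crypto.KeyGenerator

// Defaults of spark.io.encryption.keygen.algorithm and spark.io.encryption.keySizeBits.
val keyGen = KeyGenerator.getInstance("HmacSHA1")
keyGen.init(128)
val ioKey = keyGen.generateKey().getEncoded

// 128 bits of key material, usable with the AES/CTR/NoPadding transformation above.
assert(ioKey.length * java.lang.Byte.SIZE == 128)
```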
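
The stream wrapping itself has no Spark dependency; below is a standalone sketch of the pattern `createCryptoOutputStream` and `createCryptoInputStream` follow, i.e. write a random IV in the clear ahead of the ciphertext, then read it back before wrapping the rest of the stream (the key and payload are made up for illustration):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets
import java.util.Properties
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}

import org.apache.commons.crypto.random.CryptoRandomFactory
import org.apache.commons.crypto.stream.{CryptoInputStream, CryptoOutputStream}

val transformation = "AES/CTR/NoPadding"  // default of spark.io.crypto.cipher.transformation
val props = new Properties()

// A 16-byte key and 16-byte IV from Commons Crypto's secure random source.
val key = new Array[Byte](16)
val iv = new Array[Byte](16)
val random = CryptoRandomFactory.getCryptoRandom(props)
random.nextBytes(key)
random.nextBytes(iv)

// Encrypt: the IV goes to the underlying stream first, then the ciphertext.
val bos = new ByteArrayOutputStream()
bos.write(iv)
val enc = new CryptoOutputStream(transformation, props, bos,
  new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
enc.write("hello world".getBytes(StandardCharsets.UTF_8))
enc.close()

// Decrypt: recover the IV, then wrap the remainder of the stream.
val in = new ByteArrayInputStream(bos.toByteArray)
val ivOnRead = new Array[Byte](16)
com.google.common.io.ByteStreams.readFully(in, ivOnRead)
val dec = new CryptoInputStream(transformation, props, in,
  new SecretKeySpec(key, "AES"), new IvParameterSpec(ivOnRead))
val buf = new Array[Byte](64)
val n = dec.read(buf)
assert(new String(buf, 0, n, StandardCharsets.UTF_8) == "hello world")
```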