diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 199365ad925a3b119ee94e46a345cb589228a2ae..87fe56315203ed2992a8c56c3b9a42c1bdebc6f2 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -21,7 +21,6 @@ import java.lang.{Byte => JByte}
 import java.net.{Authenticator, PasswordAuthentication}
 import java.security.{KeyStore, SecureRandom}
 import java.security.cert.X509Certificate
-import javax.crypto.KeyGenerator
 import javax.net.ssl._
 
 import com.google.common.hash.HashCodes
@@ -33,7 +32,6 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.network.sasl.SecretKeyHolder
-import org.apache.spark.security.CryptoStreamUtils._
 import org.apache.spark.util.Utils
 
 /**
@@ -185,7 +183,9 @@ import org.apache.spark.util.Utils
  * setting `spark.ssl.useNodeLocalConf` to `true`.
  */
 
-private[spark] class SecurityManager(sparkConf: SparkConf)
+private[spark] class SecurityManager(
+    sparkConf: SparkConf,
+    ioEncryptionKey: Option[Array[Byte]] = None)
   extends Logging with SecretKeyHolder {
 
   import SecurityManager._
@@ -415,6 +415,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
     logInfo("Changing acls enabled to: " + aclsOn)
   }
 
+  def getIOEncryptionKey(): Option[Array[Byte]] = ioEncryptionKey
+
   /**
    * Generates or looks up the secret key.
    *
@@ -559,19 +561,4 @@ private[spark] object SecurityManager {
   // key used to store the spark secret in the Hadoop UGI
   val SECRET_LOOKUP_KEY = "sparkCookie"
 
-  /**
-   * Setup the cryptographic key used by IO encryption in credentials. The key is generated using
-   * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]].
-   */
-  def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = {
-    if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) {
-      val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
-      val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
-      val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
-      keyGen.init(keyLen)
-
-      val ioKey = keyGen.generateKey()
-      credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded)
-    }
-  }
 }
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 1261e3e735761bf31d1ea20115ffc02eeeaec297..a159a170ebc5050e4aad338b84a90e1193a015ac 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -422,10 +422,6 @@ class SparkContext(config: SparkConf) extends Logging {
     }
 
     if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true")
-    if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) {
-      throw new SparkException("IO encryption is only supported in YARN mode, please disable it " +
-        s"by setting ${IO_ENCRYPTION_ENABLED.key} to false")
-    }
 
     // "_jobProgressListener" should be set up before creating SparkEnv because when creating
     // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them.
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 1ffeb129880f96ec20df84b81e629a89e49aae98..1296386ac9bd3ad7f4062c292088fd991bfcb852 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -36,6 +36,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
 import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
+import org.apache.spark.security.CryptoStreamUtils
 import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
 import org.apache.spark.shuffle.ShuffleManager
 import org.apache.spark.storage._
@@ -165,15 +166,20 @@ object SparkEnv extends Logging {
     val bindAddress = conf.get(DRIVER_BIND_ADDRESS)
     val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS)
     val port = conf.get("spark.driver.port").toInt
+    val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
+      Some(CryptoStreamUtils.createKey(conf))
+    } else {
+      None
+    }
     create(
       conf,
       SparkContext.DRIVER_IDENTIFIER,
       bindAddress,
       advertiseAddress,
       port,
-      isDriver = true,
-      isLocal = isLocal,
-      numUsableCores = numCores,
+      isLocal,
+      numCores,
+      ioEncryptionKey,
       listenerBus = listenerBus,
       mockOutputCommitCoordinator = mockOutputCommitCoordinator
     )
@@ -189,6 +195,7 @@ object SparkEnv extends Logging {
       hostname: String,
       port: Int,
       numCores: Int,
+      ioEncryptionKey: Option[Array[Byte]],
       isLocal: Boolean): SparkEnv = {
     val env = create(
       conf,
@@ -196,9 +203,9 @@ object SparkEnv extends Logging {
       hostname,
       hostname,
       port,
-      isDriver = false,
-      isLocal = isLocal,
-      numUsableCores = numCores
+      isLocal,
+      numCores,
+      ioEncryptionKey
     )
     SparkEnv.set(env)
     env
@@ -213,18 +220,26 @@ object SparkEnv extends Logging {
       bindAddress: String,
       advertiseAddress: String,
       port: Int,
-      isDriver: Boolean,
       isLocal: Boolean,
       numUsableCores: Int,
+      ioEncryptionKey: Option[Array[Byte]],
       listenerBus: LiveListenerBus = null,
       mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
 
+    val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
+
     // Listener bus is only used on the driver
     if (isDriver) {
       assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!")
     }
 
-    val securityManager = new SecurityManager(conf)
+    val securityManager = new SecurityManager(conf, ioEncryptionKey)
+    ioEncryptionKey.foreach { _ =>
+      if (!securityManager.isSaslEncryptionEnabled()) {
+        logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " +
+          "wire.")
+      }
+    }
 
     val systemName = if (isDriver) driverSystemName else executorSystemName
     val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf,
@@ -270,7 +285,7 @@ object SparkEnv extends Logging {
       "spark.serializer", "org.apache.spark.serializer.JavaSerializer")
     logDebug(s"Using serializer: ${serializer.getClass}")
 
-    val serializerManager = new SerializerManager(serializer, conf)
+    val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey)
 
     val closureSerializer = new JavaSerializer(conf)
 
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 7eec4ae64f29699d782e499887ec33772384c723..92a27902c66965dfebeafac101912c358e2944b9 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -200,8 +200,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
       new SecurityManager(executorConf),
       clientMode = true)
     val driver = fetcher.setupEndpointRefByURI(driverUrl)
-    val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++
-      Seq[(String, String)](("spark.app.id", appId))
+    val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig)
+    val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId))
     fetcher.shutdown()
 
     // Create SparkEnv using properties we fetched from the driver.
@@ -221,7 +221,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
     }
 
     val env = SparkEnv.createExecutorEnv(
-      driverConf, executorId, hostname, port, cores, isLocal = false)
+      driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false)
 
     env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
       env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index edc8aac5d1515ba8a32f1764fed17fd76ef7e03b..0a4f19d76073e158b4897db6b1bccd533eed4d92 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -28,7 +28,12 @@ private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
 
 private[spark] object CoarseGrainedClusterMessages {
 
-  case object RetrieveSparkProps extends CoarseGrainedClusterMessage
+  case object RetrieveSparkAppConfig extends CoarseGrainedClusterMessage
+
+  case class SparkAppConfig(
+      sparkProperties: Seq[(String, String)],
+      ioEncryptionKey: Option[Array[Byte]])
+    extends CoarseGrainedClusterMessage
 
   case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 10d55c87fb8de00b9134e43ff8a137f9fb65c4a0..3452487e72e8843746446d3010a6d769f316b1c5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -206,8 +206,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
         removeExecutor(executorId, reason)
         context.reply(true)
 
-      case RetrieveSparkProps =>
-        context.reply(sparkProperties)
+      case RetrieveSparkAppConfig =>
+        val reply = SparkAppConfig(sparkProperties,
+          SparkEnv.get.securityManager.getIOEncryptionKey())
+        context.reply(reply)
     }
 
   // Make fake resource offers on all executors
diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
index f41fc38be20805c4c0c040bb1327770f5aa2141e..8e3436f13480db28f7a627826cbf237b8ea5eb54 100644
--- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
+++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
@@ -18,14 +18,13 @@ package org.apache.spark.security
 
 import java.io.{InputStream, OutputStream}
 import java.util.Properties
+import javax.crypto.KeyGenerator
 import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
 
 import org.apache.commons.crypto.random._
 import org.apache.commons.crypto.stream._
-import org.apache.hadoop.io.Text
 
 import org.apache.spark.SparkConf
-import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 
@@ -33,10 +32,6 @@ import org.apache.spark.internal.config._
  * A util class for manipulating IO encryption and decryption streams.
  */
 private[spark] object CryptoStreamUtils extends Logging {
-  /**
-   * Constants and variables for spark IO encryption
-   */
-  val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN")
 
   // The initialization vector length in bytes.
   val IV_LENGTH_IN_BYTES = 16
@@ -50,12 +45,11 @@ private[spark] object CryptoStreamUtils extends Logging {
    */
   def createCryptoOutputStream(
       os: OutputStream,
-      sparkConf: SparkConf): OutputStream = {
+      sparkConf: SparkConf,
+      key: Array[Byte]): OutputStream = {
     val properties = toCryptoConf(sparkConf)
     val iv = createInitializationVector(properties)
     os.write(iv)
-    val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
-    val key = credentials.getSecretKey(SPARK_IO_TOKEN)
     val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
     new CryptoOutputStream(transformationStr, properties, os,
       new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
@@ -66,12 +60,11 @@ private[spark] object CryptoStreamUtils extends Logging {
    */
   def createCryptoInputStream(
       is: InputStream,
-      sparkConf: SparkConf): InputStream = {
+      sparkConf: SparkConf,
+      key: Array[Byte]): InputStream = {
     val properties = toCryptoConf(sparkConf)
     val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
     is.read(iv, 0, iv.length)
-    val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
-    val key = credentials.getSecretKey(SPARK_IO_TOKEN)
     val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
     new CryptoInputStream(transformationStr, properties, is,
       new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
@@ -91,6 +84,17 @@ private[spark] object CryptoStreamUtils extends Logging {
     props
   }
 
+  /**
+   * Creates a new encryption key.
+   */
+  def createKey(conf: SparkConf): Array[Byte] = {
+    val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
+    val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
+    val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
+    keyGen.init(keyLen)
+    keyGen.generateKey().getEncoded()
+  }
+
   /**
    * This method to generate an IV (Initialization Vector) using secure random.
    */
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 2156d576f18747d50897f53b3d6382878e47a33b..ef8432ec0834ae593c2a7abd56c0a1391ea6dc46 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -33,7 +33,12 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea
  * Component which configures serialization, compression and encryption for various Spark
  * components, including automatic selection of which [[Serializer]] to use for shuffles.
  */
-private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {
+private[spark] class SerializerManager(
+    defaultSerializer: Serializer,
+    conf: SparkConf,
+    encryptionKey: Option[Array[Byte]]) {
+
+  def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None)
 
   private[this] val kryoSerializer = new KryoSerializer(conf)
 
@@ -63,9 +68,6 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
   // Whether to compress shuffle output temporarily spilled to disk
   private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
 
-  // Whether to enable IO encryption
-  private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED)
-
   /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
    * the initialization of the compression codec until it is first used. The reason is that a Spark
    * program could be using a user-defined codec in a third party jar, which is loaded in
@@ -125,14 +127,18 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
    * Wrap an input stream for encryption if shuffle encryption is enabled
    */
   private[this] def wrapForEncryption(s: InputStream): InputStream = {
-    if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s
+    encryptionKey
+      .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
+      .getOrElse(s)
   }
 
   /**
    * Wrap an output stream for encryption if shuffle encryption is enabled
    */
   private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
-    if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s
+    encryptionKey
+      .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
+      .getOrElse(s)
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
index 81eb907ac7ba608b95a51b1c75549560ccdad164..a61ec74c7df8b8369f0938c1ee8798c90ce547d0 100644
--- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
@@ -16,18 +16,21 @@
  */
 package org.apache.spark.security
 
-import java.security.PrivilegedExceptionAction
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.UUID
 
-import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+import com.google.common.io.ByteStreams
 
-import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark._
 import org.apache.spark.internal.config._
 import org.apache.spark.security.CryptoStreamUtils._
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
+import org.apache.spark.storage.TempShuffleBlockId
 
 class CryptoStreamUtilsSuite extends SparkFunSuite {
-  val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup"))
 
-  test("Crypto configuration conversion") {
+  test("crypto configuration conversion") {
     val sparkKey1 = s"${SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX}a.b.c"
     val sparkVal1 = "val1"
     val cryptoKey1 = s"${COMMONS_CRYPTO_CONF_PREFIX}a.b.c"
@@ -43,65 +46,85 @@ class CryptoStreamUtilsSuite extends SparkFunSuite {
     assert(!props.containsKey(cryptoKey2))
   }
 
-  test("Shuffle encryption is disabled by default") {
-    ugi.doAs(new PrivilegedExceptionAction[Unit]() {
-      override def run(): Unit = {
-        val credentials = UserGroupInformation.getCurrentUser.getCredentials()
-        val conf = new SparkConf()
-        initCredentials(conf, credentials)
-        assert(credentials.getSecretKey(SPARK_IO_TOKEN) === null)
-      }
-    })
+  test("shuffle encryption key length should be 128 by default") {
+    val conf = createConf()
+    var key = CryptoStreamUtils.createKey(conf)
+    val actual = key.length * (java.lang.Byte.SIZE)
+    assert(actual === 128)
   }
 
-  test("Shuffle encryption key length should be 128 by default") {
-    ugi.doAs(new PrivilegedExceptionAction[Unit]() {
-      override def run(): Unit = {
-        val credentials = UserGroupInformation.getCurrentUser.getCredentials()
-        val conf = new SparkConf()
-        conf.set(IO_ENCRYPTION_ENABLED, true)
-        initCredentials(conf, credentials)
-        var key = credentials.getSecretKey(SPARK_IO_TOKEN)
-        assert(key !== null)
-        val actual = key.length * (java.lang.Byte.SIZE)
-        assert(actual === 128)
-      }
-    })
+  test("create 256-bit key") {
+    val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "256")
+    var key = CryptoStreamUtils.createKey(conf)
+    val actual = key.length * (java.lang.Byte.SIZE)
+    assert(actual === 256)
   }
 
-  test("Initial credentials with key length in 256") {
-    ugi.doAs(new PrivilegedExceptionAction[Unit]() {
-      override def run(): Unit = {
-        val credentials = UserGroupInformation.getCurrentUser.getCredentials()
-        val conf = new SparkConf()
-        conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 256)
-        conf.set(IO_ENCRYPTION_ENABLED, true)
-        initCredentials(conf, credentials)
-        var key = credentials.getSecretKey(SPARK_IO_TOKEN)
-        assert(key !== null)
-        val actual = key.length * (java.lang.Byte.SIZE)
-        assert(actual === 256)
-      }
-    })
+  test("create key with invalid length") {
+    intercept[IllegalArgumentException] {
+      val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "328")
+      CryptoStreamUtils.createKey(conf)
+    }
   }
 
-  test("Initial credentials with invalid key length") {
-    ugi.doAs(new PrivilegedExceptionAction[Unit]() {
-      override def run(): Unit = {
-        val credentials = UserGroupInformation.getCurrentUser.getCredentials()
-        val conf = new SparkConf()
-        conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 328)
-        conf.set(IO_ENCRYPTION_ENABLED, true)
-        val thrown = intercept[IllegalArgumentException] {
-          initCredentials(conf, credentials)
-        }
-      }
-    })
+  test("serializer manager integration") {
+    val conf = createConf()
+      .set("spark.shuffle.compress", "true")
+      .set("spark.shuffle.spill.compress", "true")
+
+    val plainStr = "hello world"
+    val blockId = new TempShuffleBlockId(UUID.randomUUID())
+    val key = Some(CryptoStreamUtils.createKey(conf))
+    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf,
+      encryptionKey = key)
+
+    val outputStream = new ByteArrayOutputStream()
+    val wrappedOutputStream = serializerManager.wrapStream(blockId, outputStream)
+    wrappedOutputStream.write(plainStr.getBytes(UTF_8))
+    wrappedOutputStream.close()
+
+    val encryptedBytes = outputStream.toByteArray
+    val encryptedStr = new String(encryptedBytes, UTF_8)
+    assert(plainStr !== encryptedStr)
+
+    val inputStream = new ByteArrayInputStream(encryptedBytes)
+    val wrappedInputStream = serializerManager.wrapStream(blockId, inputStream)
+    val decryptedBytes = ByteStreams.toByteArray(wrappedInputStream)
+    val decryptedStr = new String(decryptedBytes, UTF_8)
+    assert(decryptedStr === plainStr)
   }
 
-  private[this] def initCredentials(conf: SparkConf, credentials: Credentials): Unit = {
-    if (conf.get(IO_ENCRYPTION_ENABLED)) {
-      SecurityManager.initIOEncryptionKey(conf, credentials)
+  test("encryption key propagation to executors") {
+    val conf = createConf().setAppName("Crypto Test").setMaster("local-cluster[1,1,1024]")
+    val sc = new SparkContext(conf)
+    try {
+      val content = "This is the content to be encrypted."
+      val encrypted = sc.parallelize(Seq(1))
+        .map { str =>
+          val bytes = new ByteArrayOutputStream()
+          val out = CryptoStreamUtils.createCryptoOutputStream(bytes, SparkEnv.get.conf,
+            SparkEnv.get.securityManager.getIOEncryptionKey().get)
+          out.write(content.getBytes(UTF_8))
+          out.close()
+          bytes.toByteArray()
+        }.collect()(0)
+
+      assert(content != encrypted)
+
+      val in = CryptoStreamUtils.createCryptoInputStream(new ByteArrayInputStream(encrypted),
+        sc.conf, SparkEnv.get.securityManager.getIOEncryptionKey().get)
+      val decrypted = new String(ByteStreams.toByteArray(in), UTF_8)
+      assert(content === decrypted)
+    } finally {
+      sc.stop()
     }
   }
+
+  private def createConf(extra: (String, String)*): SparkConf = {
+    val conf = new SparkConf()
+    extra.foreach { case (k, v) => conf.set(k, v) }
+    conf.set(IO_ENCRYPTION_ENABLED, true)
+    conf
+  }
+
 }
diff --git a/docs/configuration.md b/docs/configuration.md
index c2329b411fc693c44d4598cb7bde4bbb8e708d67..a6ba6cf6ee7aa5a8d01667d413792f78440a2984 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -572,7 +572,8 @@ Apart from these, the following properties are also available, and may be useful
   <td><code>spark.io.encryption.enabled</code></td>
   <td>false</td>
   <td>
-    Enable IO encryption. Only supported in YARN mode.
+    Enable IO encryption. Currently supported by all modes except Mesos. It's recommended that RPC encryption
+    be enabled when using this feature.
   </td>
 </tr>
 <tr>
diff --git a/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index 1937bd30bac51a24ff95a0c3d04f3d8408fadce5..ee9149ce0208b05aa8cbc71ed6ab51fdcfdc31a8 100644
--- a/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -75,7 +75,7 @@ private[spark] class MesosExecutorBackend
     val conf = new SparkConf(loadDefaults = true).setAll(properties)
     val port = conf.getInt("spark.executor.port", 0)
     val env = SparkEnv.createExecutorEnv(
-      conf, executorId, slaveInfo.getHostname, port, cpusPerTask, isLocal = false)
+      conf, executorId, slaveInfo.getHostname, port, cpusPerTask, None, isLocal = false)
 
     executor = new Executor(
       executorId,
diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala
index a849c4afa24f578d96e44a75283148dff93d37ab..ed29b346ba263bc781b26c39e4b65e8fbc03937b 100644
--- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala
+++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler.cluster.mesos
 
 import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.internal.config._
 import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl}
 
 /**
@@ -37,6 +38,9 @@ private[spark] class MesosClusterManager extends ExternalClusterManager {
   override def createSchedulerBackend(sc: SparkContext,
       masterURL: String,
       scheduler: TaskScheduler): SchedulerBackend = {
+    require(!sc.conf.get(IO_ENCRYPTION_ENABLED),
+      "I/O encryption is currently not supported in Mesos.")
+
     val mesosUrl = MESOS_REGEX.findFirstMatchIn(masterURL).get.group(1)
     val coarse = sc.conf.getBoolean("spark.mesos.coarse", defaultValue = true)
     if (coarse) {
diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManagerSuite.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManagerSuite.scala
index 6fce06632c57ead0ea4431d22f4c415001cc95e1..a55855428b471a20cde26fd28fe6367dc5ba13ba 100644
--- a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManagerSuite.scala
+++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManagerSuite.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.scheduler.cluster.mesos
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark._
+import org.apache.spark.internal.config._
 
 class MesosClusterManagerSuite extends SparkFunSuite with LocalSparkContext {
     def testURL(masterURL: String, expectedClass: Class[_], coarse: Boolean) {
@@ -44,4 +45,12 @@ class MesosClusterManagerSuite extends SparkFunSuite with LocalSparkContext {
             classOf[MesosFineGrainedSchedulerBackend],
             coarse = false)
     }
+
+    test("mesos with i/o encryption throws error") {
+      val se = intercept[SparkException] {
+        val conf = new SparkConf().setAppName("test").set(IO_ENCRYPTION_ENABLED, true)
+        sc = new SparkContext("mesos", "test", conf)
+      }
+      assert(se.getCause().isInstanceOf[IllegalArgumentException])
+    }
 }
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index e77fa386dc933ef439eceea2abff3374c209d16a..2c7d9d6b3ed028f5129a709beda25c51e84fbf52 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -1013,12 +1013,7 @@ private[spark] class Client(
     val securityManager = new SecurityManager(sparkConf)
     amContainer.setApplicationACLs(
       YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava)
-
-    if (sparkConf.get(IO_ENCRYPTION_ENABLED)) {
-      SecurityManager.initIOEncryptionKey(sparkConf, credentials)
-    }
     setupSecurityToken(amContainer)
-
     amContainer
   }
 
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala
deleted file mode 100644
index 1c60315b21ae80d4b01c5a8ab577f1e8d0b999b3..0000000000000000000000000000000000000000
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala
+++ /dev/null
@@ -1,108 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.yarn
-
-import java.io._
-import java.nio.charset.StandardCharsets
-import java.security.PrivilegedExceptionAction
-import java.util.UUID
-
-import org.apache.hadoop.security.{Credentials, UserGroupInformation}
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers}
-
-import org.apache.spark._
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.internal.config._
-import org.apache.spark.serializer._
-import org.apache.spark.storage._
-
-class IOEncryptionSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
-  with BeforeAndAfterEach {
-  private[this] val blockId = new TempShuffleBlockId(UUID.randomUUID())
-  private[this] val conf = new SparkConf()
-  private[this] val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup"))
-  private[this] val serializer = new KryoSerializer(conf)
-
-  override def beforeAll(): Unit = {
-    System.setProperty("SPARK_YARN_MODE", "true")
-    ugi.doAs(new PrivilegedExceptionAction[Unit]() {
-      override def run(): Unit = {
-        conf.set(IO_ENCRYPTION_ENABLED, true)
-        val creds = new Credentials()
-        SecurityManager.initIOEncryptionKey(conf, creds)
-        SparkHadoopUtil.get.addCurrentUserCredentials(creds)
-      }
-    })
-  }
-
-  override def afterAll(): Unit = {
-    SparkEnv.set(null)
-    System.clearProperty("SPARK_YARN_MODE")
-  }
-
-  override def beforeEach(): Unit = {
-    super.beforeEach()
-  }
-
-  override def afterEach(): Unit = {
-    super.afterEach()
-    conf.set("spark.shuffle.compress", false.toString)
-    conf.set("spark.shuffle.spill.compress", false.toString)
-  }
-
-  test("IO encryption read and write") {
-    ugi.doAs(new PrivilegedExceptionAction[Unit] {
-      override def run(): Unit = {
-        conf.set(IO_ENCRYPTION_ENABLED, true)
-        conf.set("spark.shuffle.compress", false.toString)
-        conf.set("spark.shuffle.spill.compress", false.toString)
-        testYarnIOEncryptionWriteRead()
-      }
-    })
-  }
-
-  test("IO encryption read and write with shuffle compression enabled") {
-    ugi.doAs(new PrivilegedExceptionAction[Unit] {
-      override def run(): Unit = {
-        conf.set(IO_ENCRYPTION_ENABLED, true)
-        conf.set("spark.shuffle.compress", true.toString)
-        conf.set("spark.shuffle.spill.compress", true.toString)
-        testYarnIOEncryptionWriteRead()
-      }
-    })
-  }
-
-  private[this] def testYarnIOEncryptionWriteRead(): Unit = {
-    val plainStr = "hello world"
-    val outputStream = new ByteArrayOutputStream()
-    val serializerManager = new SerializerManager(serializer, conf)
-    val wrappedOutputStream = serializerManager.wrapStream(blockId, outputStream)
-    wrappedOutputStream.write(plainStr.getBytes(StandardCharsets.UTF_8))
-    wrappedOutputStream.close()
-
-    val encryptedBytes = outputStream.toByteArray
-    val encryptedStr = new String(encryptedBytes)
-    assert(plainStr !== encryptedStr)
-
-    val inputStream = new ByteArrayInputStream(encryptedBytes)
-    val wrappedInputStream = serializerManager.wrapStream(blockId, inputStream)
-    val decryptedBytes = new Array[Byte](1024)
-    val len = wrappedInputStream.read(decryptedBytes)
-    val decryptedStr = new String(decryptedBytes, 0, len, StandardCharsets.UTF_8)
-    assert(decryptedStr === plainStr)
-  }
-}