Commit 4f15d94c authored by Junjie Chen's avatar Junjie Chen Committed by Marcelo Vanzin

[SPARK-13331] AES support for over-the-wire encryption

## What changes were proposed in this pull request?

The DIGEST-MD5 mechanism is used for SASL authentication and secure communication. DIGEST-MD5 supports the 3DES, DES, and RC4 ciphers; however, these ciphers are relatively slow.

AES provides better performance and security by design and, according to NIST, is a replacement for 3DES. Apache Commons Crypto is a cryptographic library optimized with AES-NI; this patch employs Apache Commons Crypto as the encryption/decryption backend for SASL authentication and the secure channel to improve Spark RPC performance.
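
For reference, the Commons Crypto API this patch builds on can be exercised in isolation. The following is a minimal, hedged sketch (the class name and payload are illustrative and not part of this patch) of an AES/CTR round trip using the same `AES/CTR/NoPadding` transformation:

```java
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Properties;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import org.apache.commons.crypto.cipher.CryptoCipher;
import org.apache.commons.crypto.cipher.CryptoCipherFactory;

public class AesCtrRoundTrip {
  public static void main(String[] args) throws Exception {
    String transform = "AES/CTR/NoPadding"; // same transformation as AesCipher.TRANSFORM
    Properties props = new Properties();    // empty properties let the factory pick a cipher class

    byte[] key = new byte[16];              // 16 bytes, matching the default cipher key size
    byte[] iv = new byte[16];               // AES block size
    new SecureRandom().nextBytes(key);
    new SecureRandom().nextBytes(iv);

    byte[] input = "over-the-wire payload".getBytes(StandardCharsets.UTF_8);
    byte[] encrypted = new byte[input.length];
    byte[] decrypted = new byte[input.length];

    // Encrypt with one cipher instance, decrypt with another, then compare.
    try (CryptoCipher enc = CryptoCipherFactory.getCryptoCipher(transform, props)) {
      enc.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv));
      enc.doFinal(input, 0, input.length, encrypted, 0);
    }
    try (CryptoCipher dec = CryptoCipherFactory.getCryptoCipher(transform, props)) {
      dec.init(Cipher.DECRYPT_MODE, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv));
      dec.doFinal(encrypted, 0, encrypted.length, decrypted, 0);
    }
    System.out.println(Arrays.equals(input, decrypted)); // prints: true
  }
}
```
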
## How was this patch tested?

Unit tests and integration tests.

Author: Junjie Chen <junjie.j.chen@intel.com>

Closes #15172 from cjjnjust/shuffle_rpc_encrypt.
parent 5ddf6947
Showing with 689 additions and 37 deletions
......@@ -76,6 +76,10 @@
<artifactId>guava</artifactId>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-crypto</artifactId>
</dependency>
<!-- Test dependencies -->
<dependency>
......
......@@ -30,6 +30,8 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.sasl.aes.AesCipher;
import org.apache.spark.network.sasl.aes.AesConfigMessage;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.TransportConf;
......@@ -88,9 +90,26 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
throw new RuntimeException(
new SaslException("Encryption requests by negotiated non-encrypted connection."));
}
SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
if (conf.aesEncryptionEnabled()) {
// Generate a request config message to send to server.
AesConfigMessage configMessage = AesCipher.createConfigMessage(conf);
ByteBuffer buf = configMessage.encodeMessage();
// Encrypt the config message.
byte[] toEncrypt = JavaUtils.bufferToArray(buf);
ByteBuffer encrypted = ByteBuffer.wrap(saslClient.wrap(toEncrypt, 0, toEncrypt.length));
client.sendRpcSync(encrypted, conf.saslRTTimeoutMs());
AesCipher cipher = new AesCipher(configMessage, conf);
logger.info("Enabling AES cipher for client channel {}", client);
cipher.addToChannel(channel);
saslClient.dispose();
} else {
SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
}
saslClient = null;
logger.debug("Channel {} configured for SASL encryption.", client);
logger.debug("Channel {} configured for encryption.", client);
}
} catch (IOException ioe) {
throw new RuntimeException(ioe);
......
......@@ -29,6 +29,8 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.sasl.aes.AesCipher;
import org.apache.spark.network.sasl.aes.AesConfigMessage;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.JavaUtils;
......@@ -59,6 +61,7 @@ class SaslRpcHandler extends RpcHandler {
private SparkSaslServer saslServer;
private boolean isComplete;
private boolean isAuthenticated;
SaslRpcHandler(
TransportConf conf,
......@@ -71,6 +74,7 @@ class SaslRpcHandler extends RpcHandler {
this.secretKeyHolder = secretKeyHolder;
this.saslServer = null;
this.isComplete = false;
this.isAuthenticated = false;
}
@Override
......@@ -80,30 +84,31 @@ class SaslRpcHandler extends RpcHandler {
delegate.receive(client, message, callback);
return;
}
if (saslServer == null || !saslServer.isComplete()) {
ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
SaslMessage saslMessage;
try {
saslMessage = SaslMessage.decode(nettyBuf);
} finally {
nettyBuf.release();
}
ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
SaslMessage saslMessage;
try {
saslMessage = SaslMessage.decode(nettyBuf);
} finally {
nettyBuf.release();
}
if (saslServer == null) {
// First message in the handshake, setup the necessary state.
client.setClientId(saslMessage.appId);
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
conf.saslServerAlwaysEncrypt());
}
if (saslServer == null) {
// First message in the handshake, setup the necessary state.
client.setClientId(saslMessage.appId);
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
conf.saslServerAlwaysEncrypt());
}
byte[] response;
try {
response = saslServer.response(JavaUtils.bufferToArray(
saslMessage.body().nioByteBuffer()));
} catch (IOException ioe) {
throw new RuntimeException(ioe);
byte[] response;
try {
response = saslServer.response(JavaUtils.bufferToArray(
saslMessage.body().nioByteBuffer()));
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
callback.onSuccess(ByteBuffer.wrap(response));
}
callback.onSuccess(ByteBuffer.wrap(response));
// Setup encryption after the SASL response is sent, otherwise the client can't parse the
// response. It's ok to change the channel pipeline here since we are processing an incoming
......@@ -111,15 +116,42 @@ class SaslRpcHandler extends RpcHandler {
// method returns. This assumes that the code ensures, through other means, that no outbound
// messages are being written to the channel while negotiation is still going on.
if (saslServer.isComplete()) {
logger.debug("SASL authentication successful for channel {}", client);
isComplete = true;
if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
logger.debug("SASL authentication successful for channel {}", client);
complete(true);
return;
}
if (!conf.aesEncryptionEnabled()) {
logger.debug("Enabling encryption for channel {}", client);
SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
saslServer = null;
} else {
saslServer.dispose();
saslServer = null;
complete(false);
return;
}
// Extra negotiation should happen after authentication, so return directly while
// processing the authentication message.
if (!isAuthenticated) {
logger.debug("SASL authentication successful for channel {}", client);
isAuthenticated = true;
return;
}
// Create the AES cipher once authentication has completed.
try {
byte[] encrypted = JavaUtils.bufferToArray(message);
ByteBuffer decrypted = ByteBuffer.wrap(saslServer.unwrap(encrypted, 0, encrypted.length));
AesConfigMessage configMessage = AesConfigMessage.decodeMessage(decrypted);
AesCipher cipher = new AesCipher(configMessage, conf);
// Send a response back to the client to confirm that the server accepted the config.
callback.onSuccess(JavaUtils.stringToBytes(AesCipher.TRANSFORM));
logger.info("Enabling AES cipher for Server channel {}", client);
cipher.addToChannel(channel);
complete(true);
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
}
}
......@@ -155,4 +187,17 @@ class SaslRpcHandler extends RpcHandler {
delegate.exceptionCaught(cause, client);
}
private void complete(boolean dispose) {
if (dispose) {
try {
saslServer.dispose();
} catch (RuntimeException e) {
logger.error("Error while disposing SASL server", e);
}
}
saslServer = null;
isComplete = true;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.sasl.aes;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.util.Properties;
import javax.crypto.spec.SecretKeySpec;
import javax.crypto.spec.IvParameterSpec;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.util.AbstractReferenceCounted;
import org.apache.commons.crypto.cipher.CryptoCipherFactory;
import org.apache.commons.crypto.random.CryptoRandom;
import org.apache.commons.crypto.random.CryptoRandomFactory;
import org.apache.commons.crypto.stream.CryptoInputStream;
import org.apache.commons.crypto.stream.CryptoOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.util.ByteArrayReadableChannel;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.TransportConf;
/**
* AES cipher for encryption and decryption.
*/
public class AesCipher {
private static final Logger logger = LoggerFactory.getLogger(AesCipher.class);
public static final String ENCRYPTION_HANDLER_NAME = "AesEncryption";
public static final String DECRYPTION_HANDLER_NAME = "AesDecryption";
public static final int STREAM_BUFFER_SIZE = 1024 * 32;
public static final String TRANSFORM = "AES/CTR/NoPadding";
private final SecretKeySpec inKeySpec;
private final IvParameterSpec inIvSpec;
private final SecretKeySpec outKeySpec;
private final IvParameterSpec outIvSpec;
private final Properties properties;
public AesCipher(AesConfigMessage configMessage, TransportConf conf) throws IOException {
this.properties = CryptoStreamUtils.toCryptoConf(conf);
this.inKeySpec = new SecretKeySpec(configMessage.inKey, "AES");
this.inIvSpec = new IvParameterSpec(configMessage.inIv);
this.outKeySpec = new SecretKeySpec(configMessage.outKey, "AES");
this.outIvSpec = new IvParameterSpec(configMessage.outIv);
}
/**
* Create AES crypto output stream
* @param ch The underlying channel to write out.
* @return Return output crypto stream for encryption.
* @throws IOException
*/
private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException {
return new CryptoOutputStream(TRANSFORM, properties, ch, outKeySpec, outIvSpec);
}
/**
* Create AES crypto input stream
* @param ch The underlying channel used to read data.
* @return Return input crypto stream for decryption.
* @throws IOException
*/
private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
return new CryptoInputStream(TRANSFORM, properties, ch, inKeySpec, inIvSpec);
}
/**
* Add handlers to channel
* @param ch the channel for adding handlers
* @throws IOException
*/
public void addToChannel(Channel ch) throws IOException {
ch.pipeline()
.addFirst(ENCRYPTION_HANDLER_NAME, new AesEncryptHandler(this))
.addFirst(DECRYPTION_HANDLER_NAME, new AesDecryptHandler(this));
}
/**
* Create the configuration message
* @param conf is the local transport configuration.
* @return Config message for sending.
*/
public static AesConfigMessage createConfigMessage(TransportConf conf) {
int keySize = conf.aesCipherKeySize();
Properties properties = CryptoStreamUtils.toCryptoConf(conf);
try {
int paramLen = CryptoCipherFactory.getCryptoCipher(AesCipher.TRANSFORM, properties)
.getBlockSize();
byte[] inKey = new byte[keySize];
byte[] outKey = new byte[keySize];
byte[] inIv = new byte[paramLen];
byte[] outIv = new byte[paramLen];
CryptoRandom random = CryptoRandomFactory.getCryptoRandom(properties);
random.nextBytes(inKey);
random.nextBytes(outKey);
random.nextBytes(inIv);
random.nextBytes(outIv);
return new AesConfigMessage(inKey, inIv, outKey, outIv);
} catch (Exception e) {
logger.error("AES config error", e);
throw Throwables.propagate(e);
}
}
/**
* CryptoStreamUtils is used to convert config from TransportConf to AES Crypto config.
*/
private static class CryptoStreamUtils {
public static Properties toCryptoConf(TransportConf conf) {
Properties props = new Properties();
if (conf.aesCipherClass() != null) {
props.setProperty(CryptoCipherFactory.CLASSES_KEY, conf.aesCipherClass());
}
return props;
}
}
private static class AesEncryptHandler extends ChannelOutboundHandlerAdapter {
private final ByteArrayWritableChannel byteChannel;
private final CryptoOutputStream cos;
AesEncryptHandler(AesCipher cipher) throws IOException {
byteChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE);
cos = cipher.createOutputStream(byteChannel);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
ctx.write(new EncryptedMessage(cos, msg, byteChannel), promise);
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
try {
cos.close();
} finally {
super.close(ctx, promise);
}
}
}
private static class AesDecryptHandler extends ChannelInboundHandlerAdapter {
private final CryptoInputStream cis;
private final ByteArrayReadableChannel byteChannel;
AesDecryptHandler(AesCipher cipher) throws IOException {
byteChannel = new ByteArrayReadableChannel();
cis = cipher.createInputStream(byteChannel);
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
byteChannel.feedData((ByteBuf) data);
byte[] decryptedData = new byte[byteChannel.readableBytes()];
int offset = 0;
while (offset < decryptedData.length) {
offset += cis.read(decryptedData, offset, decryptedData.length - offset);
}
ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length));
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
try {
cis.close();
} finally {
super.channelInactive(ctx);
}
}
}
private static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion {
private final boolean isByteBuf;
private final ByteBuf buf;
private final FileRegion region;
private long transferred;
private CryptoOutputStream cos;
// Due to the streaming issue CRYPTO-125 (https://issues.apache.org/jira/browse/CRYPTO-125), two
// helper ByteArrayWritableChannels are needed for streaming: one receives raw data from the
// upstream handler, and the other stores the encrypted data.
private ByteArrayWritableChannel byteEncChannel;
private ByteArrayWritableChannel byteRawChannel;
private ByteBuffer currentEncrypted;
EncryptedMessage(CryptoOutputStream cos, Object msg, ByteArrayWritableChannel ch) {
Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion,
"Unrecognized message type: %s", msg.getClass().getName());
this.isByteBuf = msg instanceof ByteBuf;
this.buf = isByteBuf ? (ByteBuf) msg : null;
this.region = isByteBuf ? null : (FileRegion) msg;
this.transferred = 0;
this.byteRawChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE);
this.cos = cos;
this.byteEncChannel = ch;
}
@Override
public long count() {
return isByteBuf ? buf.readableBytes() : region.count();
}
@Override
public long position() {
return 0;
}
@Override
public long transfered() {
return transferred;
}
@Override
public long transferTo(WritableByteChannel target, long position) throws IOException {
Preconditions.checkArgument(position == transfered(), "Invalid position.");
do {
if (currentEncrypted == null) {
encryptMore();
}
int bytesWritten = currentEncrypted.remaining();
target.write(currentEncrypted);
bytesWritten -= currentEncrypted.remaining();
transferred += bytesWritten;
if (!currentEncrypted.hasRemaining()) {
currentEncrypted = null;
byteEncChannel.reset();
}
} while (transferred < count());
return transferred;
}
private void encryptMore() throws IOException {
byteRawChannel.reset();
if (isByteBuf) {
int copied = byteRawChannel.write(buf.nioBuffer());
buf.skipBytes(copied);
} else {
region.transferTo(byteRawChannel, region.transfered());
}
cos.write(byteRawChannel.getData(), 0, byteRawChannel.length());
cos.flush();
currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(),
0, byteEncChannel.length());
}
@Override
protected void deallocate() {
byteRawChannel.reset();
byteEncChannel.reset();
if (region != null) {
region.release();
}
if (buf != null) {
buf.release();
}
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.sasl.aes;
import java.nio.ByteBuffer;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.apache.spark.network.protocol.Encodable;
import org.apache.spark.network.protocol.Encoders;
/**
* The AES cipher options for encryption negotiation.
*/
public class AesConfigMessage implements Encodable {
/** Serialization tag used to catch incorrect payloads. */
private static final byte TAG_BYTE = (byte) 0xEB;
public byte[] inKey;
public byte[] outKey;
public byte[] inIv;
public byte[] outIv;
public AesConfigMessage(byte[] inKey, byte[] inIv, byte[] outKey, byte[] outIv) {
if (inKey == null || inIv == null || outKey == null || outIv == null) {
throw new IllegalArgumentException("Cipher Key or IV must not be null!");
}
this.inKey = inKey;
this.inIv = inIv;
this.outKey = outKey;
this.outIv = outIv;
}
@Override
public int encodedLength() {
return 1 +
Encoders.ByteArrays.encodedLength(inKey) + Encoders.ByteArrays.encodedLength(outKey) +
Encoders.ByteArrays.encodedLength(inIv) + Encoders.ByteArrays.encodedLength(outIv);
}
@Override
public void encode(ByteBuf buf) {
buf.writeByte(TAG_BYTE);
Encoders.ByteArrays.encode(buf, inKey);
Encoders.ByteArrays.encode(buf, inIv);
Encoders.ByteArrays.encode(buf, outKey);
Encoders.ByteArrays.encode(buf, outIv);
}
/**
* Encode the config message.
* @return ByteBuffer which contains encoded config message.
*/
public ByteBuffer encodeMessage(){
ByteBuffer buf = ByteBuffer.allocate(encodedLength());
ByteBuf wrappedBuf = Unpooled.wrappedBuffer(buf);
wrappedBuf.clear();
encode(wrappedBuf);
return buf;
}
/**
* Decode the config message from a buffer.
* @param buffer the buffer containing the encoded config message
* @return config message
*/
public static AesConfigMessage decodeMessage(ByteBuffer buffer) {
ByteBuf buf = Unpooled.wrappedBuffer(buffer);
if (buf.readByte() != TAG_BYTE) {
throw new IllegalStateException("Expected AesConfigMessage, received something else"
+ " (maybe your client does not have AES enabled?)");
}
byte[] outKey = Encoders.ByteArrays.decode(buf);
byte[] outIv = Encoders.ByteArrays.decode(buf);
byte[] inKey = Encoders.ByteArrays.decode(buf);
byte[] inIv = Encoders.ByteArrays.decode(buf);
return new AesConfigMessage(inKey, inIv, outKey, outIv);
}
}
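
One detail worth calling out in the class above: encode writes the arrays in in/out order, while decodeMessage reads them back swapped, so a sender's outbound key/IV becomes the receiver's inbound key/IV. A minimal round-trip sketch, assuming a TransportConf named `conf` is in scope:

```java
// Sketch only: the decode step mirrors the directions between the two peers.
AesConfigMessage sent = AesCipher.createConfigMessage(conf);
AesConfigMessage received = AesConfigMessage.decodeMessage(sent.encodeMessage());

assert java.util.Arrays.equals(sent.outKey, received.inKey);
assert java.util.Arrays.equals(sent.inKey, received.outKey);
```
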
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.util;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import io.netty.buffer.ByteBuf;
public class ByteArrayReadableChannel implements ReadableByteChannel {
private ByteBuf data;
public int readableBytes() {
return data.readableBytes();
}
public void feedData(ByteBuf buf) {
data = buf;
}
@Override
public int read(ByteBuffer dst) throws IOException {
int totalRead = 0;
while (data.readableBytes() > 0 && dst.remaining() > 0) {
int bytesToRead = Math.min(data.readableBytes(), dst.remaining());
dst.put(data.readSlice(bytesToRead).nioBuffer());
totalRead += bytesToRead;
}
if (data.readableBytes() == 0) {
data.release();
}
return totalRead;
}
@Override
public void close() throws IOException {
}
@Override
public boolean isOpen() {
return true;
}
}
......@@ -18,6 +18,7 @@
package org.apache.spark.network.util;
import com.google.common.primitives.Ints;
import org.apache.commons.crypto.cipher.CryptoCipherFactory;
/**
* A central location that tracks all the settings we expose to users.
......@@ -175,4 +176,25 @@ public class TransportConf {
return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
}
/**
* Whether AES encryption is enabled for over-the-wire communication.
*/
public boolean aesEncryptionEnabled() {
return conf.getBoolean("spark.authenticate.encryption.aes.enabled", false);
}
/**
* The implementation class of the crypto cipher.
*/
public String aesCipherClass() {
return conf.get("spark.authenticate.encryption.aes.cipher.class", null);
}
/**
* The length in bytes of the AES cipher key, effective when the AES cipher is enabled. Note
* that the length must be 16, 24, or 32 bytes.
*/
public int aesCipherKeySize() {
return conf.getInt("spark.authenticate.encryption.aes.cipher.keySize", 16);
}
}
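
As a rough illustration of how these new getters resolve values — a sketch, assuming SystemPropertyConfigProvider (used in the test below) simply reads JVM system properties:

```java
// Sketch only: toggling AES RPC encryption via system properties.
System.setProperty("spark.authenticate.encryption.aes.enabled", "true");
System.setProperty("spark.authenticate.encryption.aes.cipher.keySize", "32"); // 256-bit key

TransportConf conf = new TransportConf("rpc", new SystemPropertyConfigProvider());
assert conf.aesEncryptionEnabled();
assert conf.aesCipherKeySize() == 32;
assert conf.aesCipherClass() == null; // fall back to Commons Crypto's default cipher class
```
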
......@@ -53,6 +53,7 @@ import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.sasl.aes.AesCipher;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
......@@ -149,7 +150,7 @@ public class SparkSaslSuite {
.when(rpcHandler)
.receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class));
SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false, false);
try {
ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
TimeUnit.SECONDS.toMillis(10));
......@@ -275,7 +276,7 @@ public class SparkSaslSuite {
new Random().nextBytes(data);
Files.write(data, file);
ctx = new SaslTestCtx(rpcHandler, true, false);
ctx = new SaslTestCtx(rpcHandler, true, false, false);
final CountDownLatch lock = new CountDownLatch(1);
......@@ -317,7 +318,7 @@ public class SparkSaslSuite {
SaslTestCtx ctx = null;
try {
ctx = new SaslTestCtx(mock(RpcHandler.class), false, false);
ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false);
fail("Should have failed to connect without encryption.");
} catch (Exception e) {
assertTrue(e.getCause() instanceof SaslException);
......@@ -336,7 +337,7 @@ public class SparkSaslSuite {
// able to understand RPCs sent to it and thus close the connection.
SaslTestCtx ctx = null;
try {
ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
ctx = new SaslTestCtx(mock(RpcHandler.class), true, true, false);
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
TimeUnit.SECONDS.toMillis(10));
fail("Should have failed to send RPC to server.");
......@@ -374,6 +375,69 @@ public class SparkSaslSuite {
}
}
@Test
public void testAesEncryption() throws Exception {
final AtomicReference<ManagedBuffer> response = new AtomicReference<>();
final File file = File.createTempFile("sasltest", ".txt");
SaslTestCtx ctx = null;
try {
final TransportConf conf = new TransportConf("rpc", new SystemPropertyConfigProvider());
final TransportConf spyConf = spy(conf);
doReturn(true).when(spyConf).aesEncryptionEnabled();
StreamManager sm = mock(StreamManager.class);
when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() {
@Override
public ManagedBuffer answer(InvocationOnMock invocation) {
return new FileSegmentManagedBuffer(spyConf, file, 0, file.length());
}
});
RpcHandler rpcHandler = mock(RpcHandler.class);
when(rpcHandler.getStreamManager()).thenReturn(sm);
byte[] data = new byte[256 * 1024 * 1024];
new Random().nextBytes(data);
Files.write(data, file);
ctx = new SaslTestCtx(rpcHandler, true, false, true);
final Object lock = new Object();
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) {
response.set((ManagedBuffer) invocation.getArguments()[1]);
response.get().retain();
synchronized (lock) {
lock.notifyAll();
}
return null;
}
}).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
synchronized (lock) {
ctx.client.fetchChunk(0, 0, callback);
lock.wait(10 * 1000);
}
verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
assertTrue(Arrays.equals(data, received));
} finally {
file.delete();
if (ctx != null) {
ctx.close();
}
if (response.get() != null) {
response.get().release();
}
}
}
private static class SaslTestCtx {
final TransportClient client;
......@@ -386,18 +450,28 @@ public class SparkSaslSuite {
SaslTestCtx(
RpcHandler rpcHandler,
boolean encrypt,
boolean disableClientEncryption)
boolean disableClientEncryption,
boolean aesEnable)
throws Exception {
TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
if (aesEnable) {
conf = spy(conf);
doReturn(true).when(conf).aesEncryptionEnabled();
}
SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
when(keyHolder.getSaslUser(anyString())).thenReturn("user");
when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
TransportContext ctx = new TransportContext(conf, rpcHandler);
this.checker = new EncryptionCheckerBootstrap();
String encryptHandlerName = aesEnable ? AesCipher.ENCRYPTION_HANDLER_NAME :
SaslEncryption.ENCRYPTION_HANDLER_NAME;
this.checker = new EncryptionCheckerBootstrap(encryptHandlerName);
this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
checker));
......@@ -437,13 +511,18 @@ public class SparkSaslSuite {
implements TransportServerBootstrap {
boolean foundEncryptionHandler;
String encryptHandlerName;
public EncryptionCheckerBootstrap(String encryptHandlerName) {
this.encryptHandlerName = encryptHandlerName;
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
if (!foundEncryptionHandler) {
foundEncryptionHandler =
ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null;
ctx.channel().pipeline().get(encryptHandlerName) != null;
}
ctx.write(msg, promise);
}
......
......@@ -1529,6 +1529,32 @@ Apart from these, the following properties are also available, and may be useful
currently supported by the external shuffle service.
</td>
</tr>
<tr>
<td><code>spark.authenticate.encryption.aes.enabled</code></td>
<td>false</td>
<td>
Enable AES for over-the-wire encryption.
</td>
</tr>
<tr>
<td><code>spark.authenticate.encryption.aes.cipher.keySize</code></td>
<td>16</td>
<td>
The length in bytes of the AES cipher key, effective when the AES cipher is enabled. AES
works with 16, 24, or 32 byte keys.
</td>
</tr>
<tr>
<td><code>spark.authenticate.encryption.aes.cipher.class</code></td>
<td>null</td>
<td>
Specify the underlying implementation class of the crypto cipher. Leave this null to use the
default. To use OpenSslCipher, users must install OpenSSL. Currently, two cipher classes are
available in the Commons Crypto library:
org.apache.commons.crypto.cipher.OpenSslCipher
org.apache.commons.crypto.cipher.JceCipher
</td>
</tr>
<tr>
<td><code>spark.core.connection.ack.wait.timeout</code></td>
<td>60s</td>
......