diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 6840a3ae831f0d0bf8ba0e640760878d15fe7008..a039d543c35e71da6ae6c6ea17834587adb7d6da 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -47,7 +47,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) private val blockHandler = newShuffleBlockHandler(transportConf) - private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler) + private val transportContext: TransportContext = + new TransportContext(transportConf, blockHandler, true) private var server: TransportServer = _ diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 43900e6f2c9726fa8dd3e9294d2ddb603c3d4481..1b64b863a9fe5c85c6ff30f0b67f4599e62e1553 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -59,15 +59,24 @@ public class TransportContext { private final TransportConf conf; private final RpcHandler rpcHandler; + private final boolean closeIdleConnections; private final MessageEncoder encoder; private final MessageDecoder decoder; public TransportContext(TransportConf conf, RpcHandler rpcHandler) { + this(conf, rpcHandler, false); + } + + public TransportContext( + TransportConf conf, + RpcHandler rpcHandler, + boolean closeIdleConnections) { this.conf = conf; this.rpcHandler = rpcHandler; this.encoder = new MessageEncoder(); this.decoder = new MessageDecoder(); + this.closeIdleConnections = closeIdleConnections; } /** @@ -144,7 +153,7 @@ public class TransportContext { TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, rpcHandler); return new TransportChannelHandler(client, responseHandler, requestHandler, - conf.connectionTimeoutMs()); + conf.connectionTimeoutMs(), closeIdleConnections); } public TransportConf getConf() { return conf; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 4952ffb44bb8b33a32dc3ff951752aef3f087c28..42a4f664e697c3e9a1b7864b360a525af62f713b 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -158,6 +158,16 @@ public class TransportClientFactory implements Closeable { } } + /** + * Create a completely new {@link TransportClient} to the given remote host / port + * But this connection is not pooled. + */ + public TransportClient createUnmanagedClient(String remoteHost, int remotePort) + throws IOException { + final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + return createClient(address); + } + /** Create a completely new {@link TransportClient} to the remote address. */ private TransportClient createClient(InetSocketAddress address) throws IOException { logger.debug("Creating new connection to " + address); diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 8e0ee709e38e3856f437eec3382bbfd6906db0e6..f8fcd1c3d7d76f1827a43338194d093e4452da22 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -55,16 +55,19 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message private final TransportResponseHandler responseHandler; private final TransportRequestHandler requestHandler; private final long requestTimeoutNs; + private final boolean closeIdleConnections; public TransportChannelHandler( TransportClient client, TransportResponseHandler responseHandler, TransportRequestHandler requestHandler, - long requestTimeoutMs) { + long requestTimeoutMs, + boolean closeIdleConnections) { this.client = client; this.responseHandler = responseHandler; this.requestHandler = requestHandler; this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000; + this.closeIdleConnections = closeIdleConnections; } public TransportClient getClient() { @@ -111,16 +114,21 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message IdleStateEvent e = (IdleStateEvent) evt; // See class comment for timeout semantics. In addition to ensuring we only timeout while // there are outstanding requests, we also do a secondary consistency check to ensure - // there's no race between the idle timeout and incrementing the numOutstandingRequests. - boolean hasInFlightRequests = responseHandler.numOutstandingRequests() > 0; + // there's no race between the idle timeout and incrementing the numOutstandingRequests + // (see SPARK-7003). boolean isActuallyOverdue = System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; - if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); - logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + - "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + - "is wrong.", address, requestTimeoutNs / 1000 / 1000); - ctx.close(); + if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { + if (responseHandler.numOutstandingRequests() > 0) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + ctx.close(); + } else if (closeIdleConnections) { + // While CloseIdleConnections is enable, we also close idle connection + ctx.close(); + } } } } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 35de5e57ccb98cdba6a5fdd5bbd88b419ce6dbb9..f4471374193060d05464a47c33a4be5445a2505d 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Collections; import java.util.HashSet; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; @@ -37,6 +38,7 @@ import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.ConfigProvider; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; @@ -177,4 +179,36 @@ public class TransportClientFactorySuite { assertFalse(c1.isActive()); assertFalse(c2.isActive()); } + + @Test + public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException { + TransportConf conf = new TransportConf(new ConfigProvider() { + + @Override + public String get(String name) { + if ("spark.shuffle.io.connectionTimeout".equals(name)) { + // We should make sure there is enough time for us to observe the channel is active + return "1s"; + } + String value = System.getProperty(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } + }); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); + TransportClientFactory factory = context.createClientFactory(); + try { + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertTrue(c1.isActive()); + long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds + while (c1.isActive() && System.currentTimeMillis() < expiredTime) { + Thread.sleep(10); + } + assertFalse(c1.isActive()); + } finally { + factory.close(); + } + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index ea6d248d66be36ad3bb305a93737d78940b60782..ef3a9dcc8711fcb96892be888fd5df308a882c6d 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -78,7 +78,7 @@ public class ExternalShuffleClient extends ShuffleClient { @Override public void init(String appId) { this.appId = appId; - TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); List<TransportClientBootstrap> bootstraps = Lists.newArrayList(); if (saslEnabled) { bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled)); @@ -137,9 +137,13 @@ public class ExternalShuffleClient extends ShuffleClient { String execId, ExecutorShuffleInfo executorInfo) throws IOException { checkInit(); - TransportClient client = clientFactory.createClient(host, port); - byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); - client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + TransportClient client = clientFactory.createUnmanagedClient(host, port); + try { + byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + } finally { + client.close(); + } } @Override