Commit 968ad972 authored by Aaron Davidson

[SPARK-7003] Improve reliability of connection failure detection between Netty block transfer service endpoints

Currently we rely on the assumption that an exception will be raised and the channel closed if two endpoints cannot communicate over a Netty TCP channel. However, this guarantee does not hold in all network environments, and [SPARK-6962](https://issues.apache.org/jira/browse/SPARK-6962) seems to point to a case where only the server side of the connection detected a fault.

This patch improves the robustness of fetch/RPC requests by adding an explicit timeout to the transport layer that closes the connection after a period of inactivity while there are outstanding requests.

NB: Excluding the testing-related code, this patch adds only around 50 lines.
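For context, a minimal standalone sketch of the Netty idle-detection pattern this change builds on: an IdleStateHandler fires an IdleStateEvent, and a later handler in the pipeline turns that event into a close. The class name IdleTimeoutSketch and the install helper are illustrative only, not part of this patch; the real TransportChannelHandler also checks for outstanding requests before closing.

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;

/** Illustrative only: closes a channel that has seen no reads or writes for timeoutSecs. */
public class IdleTimeoutSketch extends ChannelInboundHandlerAdapter {

  /** Adds the idle detector and this handler to an existing pipeline. */
  public static void install(ChannelPipeline pipeline, int timeoutSecs) {
    // readerIdleTime = 0 and writerIdleTime = 0 disable the one-directional checks;
    // only the "all idle" timer (no reads AND no writes) is armed, as in this patch.
    pipeline.addLast("idleStateHandler", new IdleStateHandler(0, 0, timeoutSecs));
    pipeline.addLast("idleTimeoutSketch", new IdleTimeoutSketch());
  }

  @Override
  public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
    if (evt instanceof IdleStateEvent
        && ((IdleStateEvent) evt).state() == IdleState.ALL_IDLE) {
      // The real handler additionally verifies there are outstanding requests and
      // re-checks the elapsed time since the last request before closing.
      ctx.close();
    } else {
      super.userEventTriggered(ctx, evt);
    }
  }
}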

Author: Aaron Davidson <aaron@databricks.com>

Closes #5584 from aarondav/timeout and squashes the following commits:

8699680 [Aaron Davidson] Address Reynold's comments
37ce656 [Aaron Davidson] [SPARK-7003] Improve reliability of connection failure detection between Netty block transfer service endpoints
parent 1be20707
@@ -22,6 +22,7 @@ import java.util.List;
import com.google.common.collect.Lists;
import io.netty.channel.Channel;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.timeout.IdleStateHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -106,6 +107,7 @@ public class TransportContext {
.addLast("encoder", encoder)
.addLast("frameDecoder", NettyUtils.createFrameDecoder())
.addLast("decoder", decoder)
.addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
// NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
// would require more logic to guarantee if this were not part of the same event loop.
.addLast("handler", channelHandler);
@@ -126,7 +128,8 @@ public class TransportContext {
TransportClient client = new TransportClient(channel, responseHandler);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
rpcHandler);
return new TransportChannelHandler(client, responseHandler, requestHandler);
return new TransportChannelHandler(client, responseHandler, requestHandler,
conf.connectionTimeoutMs());
}
public TransportConf getConf() { return conf; }
......
@@ -20,8 +20,8 @@ package org.apache.spark.network.client;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -50,13 +50,18 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
private final Map<Long, RpcResponseCallback> outstandingRpcs;
/** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
private final AtomicLong timeOfLastRequestNs;
public TransportResponseHandler(Channel channel) {
this.channel = channel;
this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
this.timeOfLastRequestNs = new AtomicLong(0);
}
public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
timeOfLastRequestNs.set(System.nanoTime());
outstandingFetches.put(streamChunkId, callback);
}
@@ -65,6 +70,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
}
public void addRpcRequest(long requestId, RpcResponseCallback callback) {
timeOfLastRequestNs.set(System.nanoTime());
outstandingRpcs.put(requestId, callback);
}
@@ -161,8 +167,12 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
}
/** Returns total number of outstanding requests (fetch requests + rpcs) */
@VisibleForTesting
public int numOutstandingRequests() {
return outstandingFetches.size() + outstandingRpcs.size();
}
/** Returns the time in nanoseconds of when the last request was sent out. */
public long getTimeOfLastRequestNs() {
return timeOfLastRequestNs.get();
}
}
@@ -19,6 +19,8 @@ package org.apache.spark.network.server;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -40,6 +42,11 @@ import org.apache.spark.network.util.NettyUtils;
* Client.
* This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler,
* for the Client's responses to the Server's requests.
*
* This class also handles timeouts from a {@link io.netty.handler.timeout.IdleStateHandler}.
* We consider a connection timed out if there are outstanding fetch or RPC requests but no traffic
* on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not
* timeout if the client is continuously sending but getting no responses, for simplicity.
*/
public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
@@ -47,14 +54,17 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
private final TransportClient client;
private final TransportResponseHandler responseHandler;
private final TransportRequestHandler requestHandler;
private final long requestTimeoutNs;
public TransportChannelHandler(
TransportClient client,
TransportResponseHandler responseHandler,
TransportRequestHandler requestHandler) {
TransportRequestHandler requestHandler,
long requestTimeoutMs) {
this.client = client;
this.responseHandler = responseHandler;
this.requestHandler = requestHandler;
this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000;
}
public TransportClient getClient() {
@@ -93,4 +103,25 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
responseHandler.handle((ResponseMessage) request);
}
}
/** Triggered based on events from an {@link io.netty.handler.timeout.IdleStateHandler}. */
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof IdleStateEvent) {
IdleStateEvent e = (IdleStateEvent) evt;
// See class comment for timeout semantics. In addition to ensuring we only timeout while
// there are outstanding requests, we also do a secondary consistency check to ensure
// there's no race between the idle timeout and incrementing the numOutstandingRequests.
boolean hasInFlightRequests = responseHandler.numOutstandingRequests() > 0;
boolean isActuallyOverdue =
System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs;
if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) {
String address = NettyUtils.getRemoteAddress(ctx.channel());
logger.error("Connection to {} has been quiet for {} ms while there are outstanding " +
"requests. Assuming connection is dead; please adjust spark.network.timeout if this " +
"is wrong.", address, requestTimeoutNs / 1000 / 1000);
ctx.close();
}
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.util;
import com.google.common.collect.Maps;
import java.util.Map;
import java.util.NoSuchElementException;
/** ConfigProvider based on a Map (copied in the constructor). */
public class MapConfigProvider extends ConfigProvider {
private final Map<String, String> config;
public MapConfigProvider(Map<String, String> config) {
this.config = Maps.newHashMap(config);
}
@Override
public String get(String name) {
String value = config.get(name);
if (value == null) {
throw new NoSuchElementException(name);
}
return value;
}
}
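As a quick usage sketch, MapConfigProvider plugs into TransportConf as shown below; the class name MapConfigProviderUsage is hypothetical, and the test suites later in this commit construct their TransportConf the same way in setUp.

import java.util.Map;
import com.google.common.collect.Maps;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class MapConfigProviderUsage {
  public static void main(String[] args) {
    // Back a TransportConf with an in-memory map instead of system properties.
    Map<String, String> configMap = Maps.newHashMap();
    configMap.put("spark.shuffle.io.connectionTimeout", "2s");
    TransportConf conf = new TransportConf(new MapConfigProvider(configMap));
    System.out.println("connection timeout ms: " + conf.connectionTimeoutMs());
  }
}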
@@ -98,7 +98,7 @@ public class NettyUtils {
return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8);
}
/** Returns the remote address on the channel or "&lt;remote address&gt;" if none exists. */
/** Returns the remote address on the channel or "&lt;unknown remote&gt;" if none exists. */
public static String getRemoteAddress(Channel channel) {
if (channel != null && channel.remoteAddress() != null) {
return channel.remoteAddress().toString();
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.Uninterruptibles;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.junit.*;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
/**
* Suite which ensures that requests that go without a response for the network timeout period are
* failed, and the connection closed.
*
* In this suite, we use 2 seconds as the connection timeout, with some slack given in the tests,
* to ensure stability in different test environments.
*/
public class RequestTimeoutIntegrationSuite {
private TransportServer server;
private TransportClientFactory clientFactory;
private StreamManager defaultManager;
private TransportConf conf;
// A large timeout that "shouldn't happen", for the sake of faulty tests not hanging forever.
private final int FOREVER = 60 * 1000;
@Before
public void setUp() throws Exception {
Map<String, String> configMap = Maps.newHashMap();
configMap.put("spark.shuffle.io.connectionTimeout", "2s");
conf = new TransportConf(new MapConfigProvider(configMap));
defaultManager = new StreamManager() {
@Override
public ManagedBuffer getChunk(long streamId, int chunkIndex) {
throw new UnsupportedOperationException();
}
};
}
@After
public void tearDown() {
if (server != null) {
server.close();
}
if (clientFactory != null) {
clientFactory.close();
}
}
// Basic suite: First request completes quickly, and second waits for longer than network timeout.
@Test
public void timeoutInactiveRequests() throws Exception {
final Semaphore semaphore = new Semaphore(1);
final byte[] response = new byte[16];
RpcHandler handler = new RpcHandler() {
@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
try {
semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
callback.onSuccess(response);
} catch (InterruptedException e) {
// do nothing
}
}
@Override
public StreamManager getStreamManager() {
return defaultManager;
}
};
TransportContext context = new TransportContext(conf, handler);
server = context.createServer();
clientFactory = context.createClientFactory();
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
// First completes quickly (semaphore starts at 1).
TestCallback callback0 = new TestCallback();
synchronized (callback0) {
client.sendRpc(new byte[0], callback0);
callback0.wait(FOREVER);
assert (callback0.success.length == response.length);
}
// Second times out after 2 seconds, with slack. Must be IOException.
TestCallback callback1 = new TestCallback();
synchronized (callback1) {
client.sendRpc(new byte[0], callback1);
callback1.wait(4 * 1000);
assert (callback1.failure != null);
assert (callback1.failure instanceof IOException);
}
semaphore.release();
}
// A timeout will cause the connection to be closed, invalidating the current TransportClient.
// It should be the case that requesting a client from the factory produces a new, valid one.
@Test
public void timeoutCleanlyClosesClient() throws Exception {
final Semaphore semaphore = new Semaphore(0);
final byte[] response = new byte[16];
RpcHandler handler = new RpcHandler() {
@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
try {
semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
callback.onSuccess(response);
} catch (InterruptedException e) {
// do nothing
}
}
@Override
public StreamManager getStreamManager() {
return defaultManager;
}
};
TransportContext context = new TransportContext(conf, handler);
server = context.createServer();
clientFactory = context.createClientFactory();
// First request should eventually fail.
TransportClient client0 =
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
TestCallback callback0 = new TestCallback();
synchronized (callback0) {
client0.sendRpc(new byte[0], callback0);
callback0.wait(FOREVER);
assert (callback0.failure instanceof IOException);
assert (!client0.isActive());
}
// Increment the semaphore and the second request should succeed quickly.
semaphore.release(2);
TransportClient client1 =
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
TestCallback callback1 = new TestCallback();
synchronized (callback1) {
client1.sendRpc(new byte[0], callback1);
callback1.wait(FOREVER);
assert (callback1.success.length == response.length);
assert (callback1.failure == null);
}
}
// The timeout is relative to the LAST request sent, which is kinda weird, but still.
// This test also makes sure the timeout works for Fetch requests as well as RPCs.
@Test
public void furtherRequestsDelay() throws Exception {
final byte[] response = new byte[16];
final StreamManager manager = new StreamManager() {
@Override
public ManagedBuffer getChunk(long streamId, int chunkIndex) {
Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS);
return new NioManagedBuffer(ByteBuffer.wrap(response));
}
};
RpcHandler handler = new RpcHandler() {
@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
throw new UnsupportedOperationException();
}
@Override
public StreamManager getStreamManager() {
return manager;
}
};
TransportContext context = new TransportContext(conf, handler);
server = context.createServer();
clientFactory = context.createClientFactory();
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
// Send one request, which will eventually fail.
TestCallback callback0 = new TestCallback();
client.fetchChunk(0, 0, callback0);
Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
// Send a second request before the first has failed.
TestCallback callback1 = new TestCallback();
client.fetchChunk(0, 1, callback1);
Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
synchronized (callback0) {
// not complete yet, but should complete soon
assert (callback0.success == null && callback0.failure == null);
callback0.wait(2 * 1000);
assert (callback0.failure instanceof IOException);
}
synchronized (callback1) {
// failed at same time as previous
assert (callback1.failure instanceof IOException);
}
}
/**
* Callback which sets 'success' or 'failure' on completion.
* Additionally notifies all waiters on this callback when invoked.
*/
class TestCallback implements RpcResponseCallback, ChunkReceivedCallback {
byte[] success;
Throwable failure;
@Override
public void onSuccess(byte[] response) {
synchronized(this) {
success = response;
this.notifyAll();
}
}
@Override
public void onFailure(Throwable e) {
synchronized(this) {
failure = e;
this.notifyAll();
}
}
@Override
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
synchronized(this) {
try {
success = buffer.nioByteBuffer().array();
this.notifyAll();
} catch (IOException e) {
// weird
}
}
}
@Override
public void onFailure(int chunkIndex, Throwable e) {
synchronized(this) {
failure = e;
this.notifyAll();
}
}
}
}
@@ -20,10 +20,11 @@ package org.apache.spark.network;
import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.NoSuchElementException;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import com.google.common.collect.Maps;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -36,9 +37,9 @@ import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.ConfigProvider;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class TransportClientFactorySuite {
@@ -70,16 +71,10 @@ public class TransportClientFactorySuite {
*/
private void testClientReuse(final int maxConnections, boolean concurrent)
throws IOException, InterruptedException {
TransportConf conf = new TransportConf(new ConfigProvider() {
@Override
public String get(String name) {
if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) {
return Integer.toString(maxConnections);
} else {
throw new NoSuchElementException();
}
}
});
Map<String, String> configMap = Maps.newHashMap();
configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections));
TransportConf conf = new TransportConf(new MapConfigProvider(configMap));
RpcHandler rpcHandler = new NoOpRpcHandler();
TransportContext context = new TransportContext(conf, rpcHandler);
......