diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index a6d390e13f396f5949eddaa7c977491b8e3abb93..c95e64e8e2cda0e384030829c5f126f15699a0fd 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -20,14 +20,16 @@ package org.apache.spark.network.server;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Random;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 
+import com.google.common.base.Preconditions;
+import io.netty.channel.Channel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
 
 /**
  * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
  * fetched as chunks by the client. Each registered buffer is one chunk.
@@ -36,18 +38,21 @@ public class OneForOneStreamManager extends StreamManager {
   private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);
 
   private final AtomicLong nextStreamId;
-  private final Map<Long, StreamState> streams;
+  private final ConcurrentHashMap<Long, StreamState> streams;
 
   /** State of a single stream. */
   private static class StreamState {
     final Iterator<ManagedBuffer> buffers;
 
+    // The channel associated with the stream.
+    Channel associatedChannel = null;
+
     // Used to keep track of the index of the buffer that the user has retrieved, just to ensure
     // that the caller only requests each chunk one at a time, in order.
     int curChunk = 0;
 
     StreamState(Iterator<ManagedBuffer> buffers) {
-      this.buffers = buffers;
+      this.buffers = Preconditions.checkNotNull(buffers);
     }
   }
 
@@ -58,6 +63,16 @@ public class OneForOneStreamManager extends StreamManager {
     streams = new ConcurrentHashMap<Long, StreamState>();
   }
 
+  @Override
+  public void registerChannel(Channel channel, long streamId) {
+    // Look the stream state up only once: a containsKey()/get() pair could
+    // race with a concurrent removal of the stream and throw a NullPointerException.
+    StreamState state = streams.get(streamId);
+    if (state != null) {
+      state.associatedChannel = channel;
+    }
+  }
+
   @Override
   public ManagedBuffer getChunk(long streamId, int chunkIndex) {
     StreamState state = streams.get(streamId);
@@ -80,12 +95,17 @@ public class OneForOneStreamManager extends StreamManager {
   }
 
   @Override
-  public void connectionTerminated(long streamId) {
-    // Release all remaining buffers.
-    StreamState state = streams.remove(streamId);
-    if (state != null && state.buffers != null) {
-      while (state.buffers.hasNext()) {
-        state.buffers.next().release();
+  public void connectionTerminated(Channel channel) {
+    // Close all streams which have been associated with the channel.
+    for (Map.Entry<Long, StreamState> entry : streams.entrySet()) {
+      StreamState state = entry.getValue();
+      if (state.associatedChannel == channel) {
+        streams.remove(entry.getKey());
+
+        // Release all remaining buffers.
+        while (state.buffers.hasNext()) {
+          state.buffers.next().release();
+        }
       }
     }
   }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
index 5a9a14a180c108af946fd31b193819a69f930b0a..929f789bf9d24225894b4626348aafc9baaf3ecc 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.server;
 
+import io.netty.channel.Channel;
+
 import org.apache.spark.network.buffer.ManagedBuffer;
 
 /**
@@ -44,9 +46,18 @@ public abstract class StreamManager {
   public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
 
   /**
-   * Indicates that the TCP connection that was tied to the given stream has been terminated. After
-   * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned
-   * up.
+   * Associates a stream with a single client connection, which is guaranteed to be the only reader
+   * of the stream. The getChunk() method will be called serially on this connection, and once the
+   * connection is closed, the stream will never be used again, enabling cleanup.
+   *
+   * This must be called before the first getChunk() on the stream, but it may be invoked multiple
+   * times with the same channel and stream id.
+   */
+  public void registerChannel(Channel channel, long streamId) { }
+
+  /**
+   * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
+   * to read from the associated streams again, so any state can be cleaned up.
    */
-  public void connectionTerminated(long streamId) { }
+  public void connectionTerminated(Channel channel) { }
 }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 1580180cc17e9087bfc2047e864db57ba3271232..e5159ab56d0d4510458bb2988a6d3b02637a3600 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -17,10 +17,7 @@
 
 package org.apache.spark.network.server;
 
-import java.util.Set;
-
 import com.google.common.base.Throwables;
-import com.google.common.collect.Sets;
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelFuture;
 import io.netty.channel.ChannelFutureListener;
@@ -62,9 +59,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
   /** Returns each chunk part of a stream. */
   private final StreamManager streamManager;
 
-  /** List of all stream ids that have been read on this handler, used for cleanup. */
-  private final Set<Long> streamIds;
-
   public TransportRequestHandler(
       Channel channel,
       TransportClient reverseClient,
@@ -73,7 +67,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
     this.reverseClient = reverseClient;
     this.rpcHandler = rpcHandler;
     this.streamManager = rpcHandler.getStreamManager();
-    this.streamIds = Sets.newHashSet();
   }
 
   @Override
@@ -82,10 +75,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
 
   @Override
   public void channelUnregistered() {
-    // Inform the StreamManager that these streams will no longer be read from.
-    for (long streamId : streamIds) {
-      streamManager.connectionTerminated(streamId);
-    }
+    streamManager.connectionTerminated(channel);
     rpcHandler.connectionTerminated(reverseClient);
   }
 
@@ -102,12 +92,12 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
 
   private void processFetchRequest(final ChunkFetchRequest req) {
     final String client = NettyUtils.getRemoteAddress(channel);
-    streamIds.add(req.streamChunkId.streamId);
 
     logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);
 
     ManagedBuffer buf;
     try {
+      streamManager.registerChannel(channel, req.streamChunkId.streamId);
       buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
     } catch (Exception e) {
       logger.error(String.format(
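
Below is a minimal usage sketch, not part of the patch itself, showing how the new channel-based lifecycle is driven end to end. It assumes OneForOneStreamManager's pre-existing registerStream(Iterator<ManagedBuffer>) method; the class name StreamLifecycleSketch, Netty's EmbeddedChannel, and NioManagedBuffer are stand-ins chosen only to keep the example self-contained.

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Iterator;

import io.netty.channel.Channel;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.server.OneForOneStreamManager;

public class StreamLifecycleSketch {
  public static void main(String[] args) {
    OneForOneStreamManager manager = new OneForOneStreamManager();

    // A stream of two single-chunk buffers.
    Iterator<ManagedBuffer> buffers = Arrays.<ManagedBuffer>asList(
        new NioManagedBuffer(ByteBuffer.wrap(new byte[] { 0 })),
        new NioManagedBuffer(ByteBuffer.wrap(new byte[] { 1 }))).iterator();
    long streamId = manager.registerStream(buffers);

    // TransportRequestHandler does this before each getChunk(); repeated calls
    // with the same channel/stream pair are harmless.
    Channel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter());
    manager.registerChannel(channel, streamId);

    // Chunks must be requested serially and in order on the owning channel.
    ManagedBuffer chunk0 = manager.getChunk(streamId, 0);
    chunk0.release();

    // If the connection now dies, the manager drops every stream tied to this
    // channel and releases the buffers the client never fetched (chunk 1 here).
    manager.connectionTerminated(channel);
  }
}

The design point of the patch: cleanup is keyed on the Channel rather than on individual stream IDs, so TransportRequestHandler no longer has to remember which streams it served. connectionTerminated(channel) is invoked once per dead connection, and the StreamManager itself decides which streams that connection owned and releases their unread buffers.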