From 16860327286bc08b4e2283d51b4c8fe024ba5006 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh <viirya@gmail.com> Date: Fri, 1 May 2015 11:59:12 -0700 Subject: [PATCH] [SPARK-7183] [NETWORK] Fix memory leak of TransportRequestHandler.streamIds JIRA: https://issues.apache.org/jira/browse/SPARK-7183 Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #5743 from viirya/fix_requesthandler_memory_leak and squashes the following commits: cf2c086 [Liang-Chi Hsieh] For comments. 97e205c [Liang-Chi Hsieh] Remove unused import. d35f19a [Liang-Chi Hsieh] For comments. f9a0c37 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into fix_requesthandler_memory_leak 45908b7 [Liang-Chi Hsieh] for style. 17f020f [Liang-Chi Hsieh] Remove unused import. 37a4b6c [Liang-Chi Hsieh] Remove streamIds from TransportRequestHandler. 3b3f38a [Liang-Chi Hsieh] Fix memory leak of TransportRequestHandler.streamIds. --- .../server/OneForOneStreamManager.java | 35 ++++++++++++++----- .../spark/network/server/StreamManager.java | 19 +++++++--- .../server/TransportRequestHandler.java | 14 ++------ 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index a6d390e13f..c95e64e8e2 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -20,14 +20,18 @@ package org.apache.spark.network.server; import java.util.Iterator; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import com.google.common.base.Preconditions; + /** * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually * fetched as chunks by the client. Each registered buffer is one chunk. @@ -36,18 +40,21 @@ public class OneForOneStreamManager extends StreamManager { private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class); private final AtomicLong nextStreamId; - private final Map<Long, StreamState> streams; + private final ConcurrentHashMap<Long, StreamState> streams; /** State of a single stream. */ private static class StreamState { final Iterator<ManagedBuffer> buffers; + // The channel associated to the stream + Channel associatedChannel = null; + // Used to keep track of the index of the buffer that the user has retrieved, just to ensure // that the caller only requests each chunk one at a time, in order. int curChunk = 0; StreamState(Iterator<ManagedBuffer> buffers) { - this.buffers = buffers; + this.buffers = Preconditions.checkNotNull(buffers); } } @@ -58,6 +65,13 @@ public class OneForOneStreamManager extends StreamManager { streams = new ConcurrentHashMap<Long, StreamState>(); } + @Override + public void registerChannel(Channel channel, long streamId) { + if (streams.containsKey(streamId)) { + streams.get(streamId).associatedChannel = channel; + } + } + @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { StreamState state = streams.get(streamId); @@ -80,12 +94,17 @@ public class OneForOneStreamManager extends StreamManager { } @Override - public void connectionTerminated(long streamId) { - // Release all remaining buffers. - StreamState state = streams.remove(streamId); - if (state != null && state.buffers != null) { - while (state.buffers.hasNext()) { - state.buffers.next().release(); + public void connectionTerminated(Channel channel) { + // Close all streams which have been associated with the channel. + for (Map.Entry<Long, StreamState> entry: streams.entrySet()) { + StreamState state = entry.getValue(); + if (state.associatedChannel == channel) { + streams.remove(entry.getKey()); + + // Release all remaining buffers. + while (state.buffers.hasNext()) { + state.buffers.next().release(); + } } } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java index 5a9a14a180..929f789bf9 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -17,6 +17,8 @@ package org.apache.spark.network.server; +import io.netty.channel.Channel; + import org.apache.spark.network.buffer.ManagedBuffer; /** @@ -44,9 +46,18 @@ public abstract class StreamManager { public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); /** - * Indicates that the TCP connection that was tied to the given stream has been terminated. After - * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned - * up. + * Associates a stream with a single client connection, which is guaranteed to be the only reader + * of the stream. The getChunk() method will be called serially on this connection and once the + * connection is closed, the stream will never be used again, enabling cleanup. + * + * This must be called before the first getChunk() on the stream, but it may be invoked multiple + * times with the same channel and stream id. + */ + public void registerChannel(Channel channel, long streamId) { } + + /** + * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not + * to read from the associated streams again, so any state can be cleaned up. */ - public void connectionTerminated(long streamId) { } + public void connectionTerminated(Channel channel) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 1580180cc1..e5159ab56d 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,10 +17,7 @@ package org.apache.spark.network.server; -import java.util.Set; - import com.google.common.base.Throwables; -import com.google.common.collect.Sets; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; @@ -62,9 +59,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { /** Returns each chunk part of a stream. */ private final StreamManager streamManager; - /** List of all stream ids that have been read on this handler, used for cleanup. */ - private final Set<Long> streamIds; - public TransportRequestHandler( Channel channel, TransportClient reverseClient, @@ -73,7 +67,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { this.reverseClient = reverseClient; this.rpcHandler = rpcHandler; this.streamManager = rpcHandler.getStreamManager(); - this.streamIds = Sets.newHashSet(); } @Override @@ -82,10 +75,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { @Override public void channelUnregistered() { - // Inform the StreamManager that these streams will no longer be read from. - for (long streamId : streamIds) { - streamManager.connectionTerminated(streamId); - } + streamManager.connectionTerminated(channel); rpcHandler.connectionTerminated(reverseClient); } @@ -102,12 +92,12 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { private void processFetchRequest(final ChunkFetchRequest req) { final String client = NettyUtils.getRemoteAddress(channel); - streamIds.add(req.streamChunkId.streamId); logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); ManagedBuffer buf; try { + streamManager.registerChannel(channel, req.streamChunkId.streamId); buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); } catch (Exception e) { logger.error(String.format( -- GitLab