Skip to content
Snippets Groups Projects
Commit 16860327 authored by Liang-Chi Hsieh's avatar Liang-Chi Hsieh Committed by Aaron Davidson
Browse files

[SPARK-7183] [NETWORK] Fix memory leak of TransportRequestHandler.streamIds

JIRA: https://issues.apache.org/jira/browse/SPARK-7183

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #5743 from viirya/fix_requesthandler_memory_leak and squashes the following commits:

cf2c086 [Liang-Chi Hsieh] For comments.
97e205c [Liang-Chi Hsieh] Remove unused import.
d35f19a [Liang-Chi Hsieh] For comments.
f9a0c37 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into fix_requesthandler_memory_leak
45908b7 [Liang-Chi Hsieh] for style.
17f020f [Liang-Chi Hsieh] Remove unused import.
37a4b6c [Liang-Chi Hsieh] Remove streamIds from TransportRequestHandler.
3b3f38a [Liang-Chi Hsieh] Fix memory leak of TransportRequestHandler.streamIds.
parent 1262e310
No related branches found
No related tags found
No related merge requests found
...@@ -20,14 +20,18 @@ package org.apache.spark.network.server; ...@@ -20,14 +20,18 @@ package org.apache.spark.network.server;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import io.netty.channel.Channel;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer;
import com.google.common.base.Preconditions;
/** /**
 * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually  * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
* fetched as chunks by the client. Each registered buffer is one chunk. * fetched as chunks by the client. Each registered buffer is one chunk.
...@@ -36,18 +40,21 @@ public class OneForOneStreamManager extends StreamManager { ...@@ -36,18 +40,21 @@ public class OneForOneStreamManager extends StreamManager {
private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class); private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);
private final AtomicLong nextStreamId; private final AtomicLong nextStreamId;
private final Map<Long, StreamState> streams; private final ConcurrentHashMap<Long, StreamState> streams;
/** State of a single stream. */ /** State of a single stream. */
private static class StreamState { private static class StreamState {
final Iterator<ManagedBuffer> buffers; final Iterator<ManagedBuffer> buffers;
// The channel associated to the stream
Channel associatedChannel = null;
// Used to keep track of the index of the buffer that the user has retrieved, just to ensure // Used to keep track of the index of the buffer that the user has retrieved, just to ensure
// that the caller only requests each chunk one at a time, in order. // that the caller only requests each chunk one at a time, in order.
int curChunk = 0; int curChunk = 0;
StreamState(Iterator<ManagedBuffer> buffers) { StreamState(Iterator<ManagedBuffer> buffers) {
this.buffers = buffers; this.buffers = Preconditions.checkNotNull(buffers);
} }
} }
...@@ -58,6 +65,13 @@ public class OneForOneStreamManager extends StreamManager { ...@@ -58,6 +65,13 @@ public class OneForOneStreamManager extends StreamManager {
streams = new ConcurrentHashMap<Long, StreamState>(); streams = new ConcurrentHashMap<Long, StreamState>();
} }
@Override
public void registerChannel(Channel channel, long streamId) {
  // Associates the stream with the channel so connectionTerminated(channel)
  // can later release the stream's remaining buffers.
  //
  // Use a single get() with a null check rather than containsKey() + get():
  // on a ConcurrentHashMap another thread may remove the entry between the
  // two calls (check-then-act race), which would make get() return null and
  // throw a NullPointerException here.
  StreamState state = streams.get(streamId);
  if (state != null) {
    state.associatedChannel = channel;
  }
}
@Override @Override
public ManagedBuffer getChunk(long streamId, int chunkIndex) { public ManagedBuffer getChunk(long streamId, int chunkIndex) {
StreamState state = streams.get(streamId); StreamState state = streams.get(streamId);
...@@ -80,12 +94,17 @@ public class OneForOneStreamManager extends StreamManager { ...@@ -80,12 +94,17 @@ public class OneForOneStreamManager extends StreamManager {
} }
@Override @Override
public void connectionTerminated(long streamId) { public void connectionTerminated(Channel channel) {
// Release all remaining buffers. // Close all streams which have been associated with the channel.
StreamState state = streams.remove(streamId); for (Map.Entry<Long, StreamState> entry: streams.entrySet()) {
if (state != null && state.buffers != null) { StreamState state = entry.getValue();
while (state.buffers.hasNext()) { if (state.associatedChannel == channel) {
state.buffers.next().release(); streams.remove(entry.getKey());
// Release all remaining buffers.
while (state.buffers.hasNext()) {
state.buffers.next().release();
}
} }
} }
} }
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
package org.apache.spark.network.server; package org.apache.spark.network.server;
import io.netty.channel.Channel;
import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer;
/** /**
...@@ -44,9 +46,18 @@ public abstract class StreamManager { ...@@ -44,9 +46,18 @@ public abstract class StreamManager {
public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
/** /**
* Indicates that the TCP connection that was tied to the given stream has been terminated. After * Associates a stream with a single client connection, which is guaranteed to be the only reader
* this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned * of the stream. The getChunk() method will be called serially on this connection and once the
* up. * connection is closed, the stream will never be used again, enabling cleanup.
*
* This must be called before the first getChunk() on the stream, but it may be invoked multiple
* times with the same channel and stream id.
*/
public void registerChannel(Channel channel, long streamId) { }
/**
* Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
* to read from the associated streams again, so any state can be cleaned up.
*/ */
public void connectionTerminated(long streamId) { } public void connectionTerminated(Channel channel) { }
} }
...@@ -17,10 +17,7 @@ ...@@ -17,10 +17,7 @@
package org.apache.spark.network.server; package org.apache.spark.network.server;
import java.util.Set;
import com.google.common.base.Throwables; import com.google.common.base.Throwables;
import com.google.common.collect.Sets;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
...@@ -62,9 +59,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { ...@@ -62,9 +59,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
/** Returns each chunk part of a stream. */ /** Returns each chunk part of a stream. */
private final StreamManager streamManager; private final StreamManager streamManager;
/** List of all stream ids that have been read on this handler, used for cleanup. */
private final Set<Long> streamIds;
public TransportRequestHandler( public TransportRequestHandler(
Channel channel, Channel channel,
TransportClient reverseClient, TransportClient reverseClient,
...@@ -73,7 +67,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { ...@@ -73,7 +67,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
this.reverseClient = reverseClient; this.reverseClient = reverseClient;
this.rpcHandler = rpcHandler; this.rpcHandler = rpcHandler;
this.streamManager = rpcHandler.getStreamManager(); this.streamManager = rpcHandler.getStreamManager();
this.streamIds = Sets.newHashSet();
} }
@Override @Override
...@@ -82,10 +75,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { ...@@ -82,10 +75,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
@Override @Override
public void channelUnregistered() { public void channelUnregistered() {
// Inform the StreamManager that these streams will no longer be read from. streamManager.connectionTerminated(channel);
for (long streamId : streamIds) {
streamManager.connectionTerminated(streamId);
}
rpcHandler.connectionTerminated(reverseClient); rpcHandler.connectionTerminated(reverseClient);
} }
...@@ -102,12 +92,12 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { ...@@ -102,12 +92,12 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
private void processFetchRequest(final ChunkFetchRequest req) { private void processFetchRequest(final ChunkFetchRequest req) {
final String client = NettyUtils.getRemoteAddress(channel); final String client = NettyUtils.getRemoteAddress(channel);
streamIds.add(req.streamChunkId.streamId);
logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);
ManagedBuffer buf; ManagedBuffer buf;
try { try {
streamManager.registerChannel(channel, req.streamChunkId.streamId);
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
} catch (Exception e) { } catch (Exception e) {
logger.error(String.format( logger.error(String.format(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment