From 16860327286bc08b4e2283d51b4c8fe024ba5006 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <viirya@gmail.com>
Date: Fri, 1 May 2015 11:59:12 -0700
Subject: [PATCH] [SPARK-7183] [NETWORK] Fix memory leak of
 TransportRequestHandler.streamIds

JIRA: https://issues.apache.org/jira/browse/SPARK-7183

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #5743 from viirya/fix_requesthandler_memory_leak and squashes the following commits:

cf2c086 [Liang-Chi Hsieh] For comments.
97e205c [Liang-Chi Hsieh] Remove unused import.
d35f19a [Liang-Chi Hsieh] For comments.
f9a0c37 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into fix_requesthandler_memory_leak
45908b7 [Liang-Chi Hsieh] for style.
17f020f [Liang-Chi Hsieh] Remove unused import.
37a4b6c [Liang-Chi Hsieh] Remove streamIds from TransportRequestHandler.
3b3f38a [Liang-Chi Hsieh] Fix memory leak of TransportRequestHandler.streamIds.
---
 .../server/OneForOneStreamManager.java        | 35 ++++++++++++++-----
 .../spark/network/server/StreamManager.java   | 19 +++++++---
 .../server/TransportRequestHandler.java       | 14 ++------
 3 files changed, 44 insertions(+), 24 deletions(-)

diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index a6d390e13f..c95e64e8e2 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -20,14 +20,18 @@ package org.apache.spark.network.server;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Random;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 
+import io.netty.channel.Channel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
 
+import com.google.common.base.Preconditions;
+
 /**
  * StreamManager which allows registration of an Iterator&lt;ManagedBuffer&gt;, which are individually
  * fetched as chunks by the client. Each registered buffer is one chunk.
@@ -36,18 +40,21 @@ public class OneForOneStreamManager extends StreamManager {
   private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);
 
   private final AtomicLong nextStreamId;
-  private final Map<Long, StreamState> streams;
+  private final ConcurrentHashMap<Long, StreamState> streams;
 
   /** State of a single stream. */
   private static class StreamState {
     final Iterator<ManagedBuffer> buffers;
 
+    // The channel associated to the stream
+    Channel associatedChannel = null;
+
     // Used to keep track of the index of the buffer that the user has retrieved, just to ensure
     // that the caller only requests each chunk one at a time, in order.
     int curChunk = 0;
 
     StreamState(Iterator<ManagedBuffer> buffers) {
-      this.buffers = buffers;
+      this.buffers = Preconditions.checkNotNull(buffers);
     }
   }
 
@@ -58,6 +65,13 @@ public class OneForOneStreamManager extends StreamManager {
     streams = new ConcurrentHashMap<Long, StreamState>();
   }
 
+  @Override
+  public void registerChannel(Channel channel, long streamId) {
+    if (streams.containsKey(streamId)) {
+      streams.get(streamId).associatedChannel = channel;
+    }
+  }
+
   @Override
   public ManagedBuffer getChunk(long streamId, int chunkIndex) {
     StreamState state = streams.get(streamId);
@@ -80,12 +94,17 @@ public class OneForOneStreamManager extends StreamManager {
   }
 
   @Override
-  public void connectionTerminated(long streamId) {
-    // Release all remaining buffers.
-    StreamState state = streams.remove(streamId);
-    if (state != null && state.buffers != null) {
-      while (state.buffers.hasNext()) {
-        state.buffers.next().release();
+  public void connectionTerminated(Channel channel) {
+    // Close all streams which have been associated with the channel.
+    for (Map.Entry<Long, StreamState> entry: streams.entrySet()) {
+      StreamState state = entry.getValue();
+      if (state.associatedChannel == channel) {
+        streams.remove(entry.getKey());
+
+        // Release all remaining buffers.
+        while (state.buffers.hasNext()) {
+          state.buffers.next().release();
+        }
       }
     }
   }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
index 5a9a14a180..929f789bf9 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.server;
 
+import io.netty.channel.Channel;
+
 import org.apache.spark.network.buffer.ManagedBuffer;
 
 /**
@@ -44,9 +46,18 @@ public abstract class StreamManager {
   public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
 
   /**
-   * Indicates that the TCP connection that was tied to the given stream has been terminated. After
-   * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned
-   * up.
+   * Associates a stream with a single client connection, which is guaranteed to be the only reader
+   * of the stream. The getChunk() method will be called serially on this connection and once the
+   * connection is closed, the stream will never be used again, enabling cleanup.
+   *
+   * This must be called before the first getChunk() on the stream, but it may be invoked multiple
+   * times with the same channel and stream id.
+   */
+  public void registerChannel(Channel channel, long streamId) { }
+
+  /**
+   * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
+   * to read from the associated streams again, so any state can be cleaned up.
    */
-  public void connectionTerminated(long streamId) { }
+  public void connectionTerminated(Channel channel) { }
 }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 1580180cc1..e5159ab56d 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -17,10 +17,7 @@
 
 package org.apache.spark.network.server;
 
-import java.util.Set;
-
 import com.google.common.base.Throwables;
-import com.google.common.collect.Sets;
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelFuture;
 import io.netty.channel.ChannelFutureListener;
@@ -62,9 +59,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
   /** Returns each chunk part of a stream. */
   private final StreamManager streamManager;
 
-  /** List of all stream ids that have been read on this handler, used for cleanup. */
-  private final Set<Long> streamIds;
-
   public TransportRequestHandler(
       Channel channel,
       TransportClient reverseClient,
@@ -73,7 +67,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
     this.reverseClient = reverseClient;
     this.rpcHandler = rpcHandler;
     this.streamManager = rpcHandler.getStreamManager();
-    this.streamIds = Sets.newHashSet();
   }
 
   @Override
@@ -82,10 +75,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
 
   @Override
   public void channelUnregistered() {
-    // Inform the StreamManager that these streams will no longer be read from.
-    for (long streamId : streamIds) {
-      streamManager.connectionTerminated(streamId);
-    }
+    streamManager.connectionTerminated(channel);
     rpcHandler.connectionTerminated(reverseClient);
   }
 
@@ -102,12 +92,12 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
 
   private void processFetchRequest(final ChunkFetchRequest req) {
     final String client = NettyUtils.getRemoteAddress(channel);
-    streamIds.add(req.streamChunkId.streamId);
 
     logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);
 
     ManagedBuffer buf;
     try {
+      streamManager.registerChannel(channel, req.streamChunkId.streamId);
       buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
     } catch (Exception e) {
       logger.error(String.format(
-- 
GitLab