diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index a6d390e13f396f5949eddaa7c977491b8e3abb93..c95e64e8e2cda0e384030829c5f126f15699a0fd 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -20,14 +20,18 @@ package org.apache.spark.network.server;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Random;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 
+import io.netty.channel.Channel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
 
+import com.google.common.base.Preconditions;
+
 /**
  * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
  * fetched as chunks by the client. Each registered buffer is one chunk.
@@ -36,18 +40,21 @@ public class OneForOneStreamManager extends StreamManager {
   private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);
 
   private final AtomicLong nextStreamId;
-  private final Map<Long, StreamState> streams;
+  private final ConcurrentHashMap<Long, StreamState> streams;
 
   /** State of a single stream. */
   private static class StreamState {
     final Iterator<ManagedBuffer> buffers;
 
+    // The channel associated to the stream
+    Channel associatedChannel = null;
+
     // Used to keep track of the index of the buffer that the user has retrieved, just to ensure
     // that the caller only requests each chunk one at a time, in order.
     int curChunk = 0;
 
     StreamState(Iterator<ManagedBuffer> buffers) {
-      this.buffers = buffers;
+      this.buffers = Preconditions.checkNotNull(buffers);
     }
   }
 
@@ -58,6 +65,13 @@ public class OneForOneStreamManager extends StreamManager {
     streams = new ConcurrentHashMap<Long, StreamState>();
   }
 
+  @Override
+  public void registerChannel(Channel channel, long streamId) {
+    if (streams.containsKey(streamId)) {
+      streams.get(streamId).associatedChannel = channel;
+    }
+  }
+
   @Override
   public ManagedBuffer getChunk(long streamId, int chunkIndex) {
     StreamState state = streams.get(streamId);
@@ -80,12 +94,17 @@ public class OneForOneStreamManager extends StreamManager {
   }
 
   @Override
-  public void connectionTerminated(long streamId) {
-    // Release all remaining buffers.
-    StreamState state = streams.remove(streamId);
-    if (state != null && state.buffers != null) {
-      while (state.buffers.hasNext()) {
-        state.buffers.next().release();
+  public void connectionTerminated(Channel channel) {
+    // Close all streams which have been associated with the channel.
+    for (Map.Entry<Long, StreamState> entry: streams.entrySet()) {
+      StreamState state = entry.getValue();
+      if (state.associatedChannel == channel) {
+        streams.remove(entry.getKey());
+
+        // Release all remaining buffers.
+        while (state.buffers.hasNext()) {
+          state.buffers.next().release();
+        }
       }
     }
   }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
index 5a9a14a180c108af946fd31b193819a69f930b0a..929f789bf9d24225894b4626348aafc9baaf3ecc 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.server;
 
+import io.netty.channel.Channel;
+
 import org.apache.spark.network.buffer.ManagedBuffer;
 
 /**
@@ -44,9 +46,18 @@ public abstract class StreamManager {
   public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
 
   /**
-   * Indicates that the TCP connection that was tied to the given stream has been terminated. After
-   * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned
-   * up.
+   * Associates a stream with a single client connection, which is guaranteed to be the only reader
+   * of the stream. The getChunk() method will be called serially on this connection and once the
+   * connection is closed, the stream will never be used again, enabling cleanup.
+   *
+   * This must be called before the first getChunk() on the stream, but it may be invoked multiple
+   * times with the same channel and stream id.
+   */
+  public void registerChannel(Channel channel, long streamId) { }
+
+  /**
+   * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
+   * to read from the associated streams again, so any state can be cleaned up.
    */
-  public void connectionTerminated(long streamId) { }
+  public void connectionTerminated(Channel channel) { }
 }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 1580180cc17e9087bfc2047e864db57ba3271232..e5159ab56d0d4510458bb2988a6d3b02637a3600 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -17,10 +17,7 @@
 
 package org.apache.spark.network.server;
 
-import java.util.Set;
-
 import com.google.common.base.Throwables;
-import com.google.common.collect.Sets;
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelFuture;
 import io.netty.channel.ChannelFutureListener;
@@ -62,9 +59,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
   /** Returns each chunk part of a stream. */
   private final StreamManager streamManager;
 
-  /** List of all stream ids that have been read on this handler, used for cleanup. */
-  private final Set<Long> streamIds;
-
   public TransportRequestHandler(
       Channel channel,
       TransportClient reverseClient,
@@ -73,7 +67,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
     this.reverseClient = reverseClient;
     this.rpcHandler = rpcHandler;
     this.streamManager = rpcHandler.getStreamManager();
-    this.streamIds = Sets.newHashSet();
   }
 
   @Override
@@ -82,10 +75,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
 
   @Override
   public void channelUnregistered() {
-    // Inform the StreamManager that these streams will no longer be read from.
-    for (long streamId : streamIds) {
-      streamManager.connectionTerminated(streamId);
-    }
+    streamManager.connectionTerminated(channel);
     rpcHandler.connectionTerminated(reverseClient);
   }
 
@@ -102,12 +92,12 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
 
   private void processFetchRequest(final ChunkFetchRequest req) {
     final String client = NettyUtils.getRemoteAddress(channel);
-    streamIds.add(req.streamChunkId.streamId);
 
     logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);
 
     ManagedBuffer buf;
     try {
+      streamManager.registerChannel(channel, req.streamChunkId.streamId);
       buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
     } catch (Exception e) {
       logger.error(String.format(