diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 597d46a3d2223f0bca3d0b6ecd95876d2b3c1ee1..9d8e7e9f03aeaf3065fda77b7163f8428a7343a6 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -17,29 +17,29 @@
 
 package org.apache.spark.shuffle.hash
 
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.util.{Failure, Success, Try}
+import java.io.InputStream
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.util.{Failure, Success}
 
 import org.apache.spark._
-import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
-import org.apache.spark.util.CompletionIterator
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
+  ShuffleBlockId}
 
 private[hash] object BlockStoreShuffleFetcher extends Logging {
-  def fetch[T](
+  def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,
       context: TaskContext,
-      serializer: Serializer)
-    : Iterator[T] =
+      blockManager: BlockManager,
+      mapOutputTracker: MapOutputTracker)
+    : Iterator[(BlockId, InputStream)] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-    val blockManager = SparkEnv.get.blockManager
 
     val startTime = System.currentTimeMillis
-    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
+    val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
     logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
       shuffleId, reduceId, System.currentTimeMillis - startTime))
 
@@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
         (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
     }
 
-    def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
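+    // Fetch the blocks as (BlockId, InputStream) pairs; remote fetches are throttled so that
+    // no more than maxBytesInFlight bytes are in flight at any time.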
+    val blockFetcherItr = new ShuffleBlockFetcherIterator(
+      context,
+      blockManager.shuffleClient,
+      blockManager,
+      blocksByAddress,
+      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
+
+    // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler
+    blockFetcherItr.map { blockPair =>
       val blockId = blockPair._1
       val blockOption = blockPair._2
       blockOption match {
-        case Success(block) => {
-          block.asInstanceOf[Iterator[T]]
+        case Success(inputStream) => {
+          (blockId, inputStream)
         }
         case Failure(e) => {
           blockId match {
@@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
         }
       }
     }
-
-    val blockFetcherItr = new ShuffleBlockFetcherIterator(
-      context,
-      SparkEnv.get.blockManager.shuffleClient,
-      blockManager,
-      blocksByAddress,
-      serializer,
-      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
-      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
-    val itr = blockFetcherItr.flatMap(unpackBlock)
-
-    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
-      context.taskMetrics.updateShuffleReadMetrics()
-    })
-
-    new InterruptibleIterator[T](context, completionIter) {
-      val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-      override def next(): T = {
-        readMetrics.incRecordsRead(1)
-        delegate.next()
-      }
-    }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 41bafabde05b9c289a37fcf3e8c69d8bef4d9823..d5c9880659dd3f19bba86dddc4f78fb6cb67e450 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,16 +17,20 @@
 
 package org.apache.spark.shuffle.hash
 
-import org.apache.spark.{InterruptibleIterator, TaskContext}
+import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
     endPartition: Int,
-    context: TaskContext)
+    context: TaskContext,
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
   extends ShuffleReader[K, C]
 {
   require(endPartition == startPartition + 1,
@@ -36,20 +40,52 @@ private[spark] class HashShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
+    val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
+      handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)
+
+    // Wrap the streams for compression based on configuration
+    val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
+      blockManager.wrapForCompression(blockId, inputStream)
+    }
+
     val ser = Serializer.getSerializer(dep.serializer)
-    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
+    val serializerInstance = ser.newInstance()
+
+    // Create a key/value iterator for each stream
+    val recordIter = wrappedStreams.flatMap { wrappedStream =>
+      // Note: asKeyValueIterator below wraps a key/value iterator inside a NextIterator.
+      // The NextIterator makes sure that close() is called on the underlying InputStream
+      // once all records have been read.
+      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+    }
+
+    // Update the context task metrics for each record read.
+    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+      recordIter.map(record => {
+        readMetrics.incRecordsRead(1)
+        record
+      }),
+      context.taskMetrics().updateShuffleReadMetrics())
+
+    // An interruptible iterator must be used here in order to support task cancellation
+    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
 
     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
-        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
+        // We are reading values that are already combined
+        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
       } else {
-        new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
+        // We don't know the value type, but also don't care -- the dependency *should*
+        // have made sure it is compatible with this aggregator, which will convert the
+        // value type to the combined type C.
+        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
       }
     } else {
       require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
-
-      // Convert the Product2s to pairs since this is what downstream RDDs currently expect
-      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
+      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
     }
 
     // Sort the output if there is a sort ordering defined.
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index d0faab62c9e9eaa2748d1b30c50f0f7c539355b4..e49e39679e940ac991ea6a844957bda479f20f79 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,23 +17,23 @@
 
 package org.apache.spark.storage
 
+import java.io.InputStream
 import java.util.concurrent.LinkedBlockingQueue
 
 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
 import scala.util.{Failure, Try}
 
 import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
-import org.apache.spark.util.{CompletionIterator, Utils}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.util.Utils
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
  * manager. For remote blocks, it fetches them using the provided BlockTransferService.
  *
- * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a
- * pipelined fashion as they are received.
+ * This creates an iterator of (BlockId, Try[InputStream]) tuples so the caller can handle blocks
+ * in a pipelined fashion as they are received.
  *
 * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
  * using too much memory.
@@ -44,7 +44,6 @@ import org.apache.spark.util.{CompletionIterator, Utils}
  * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
  *                        For each block we also require the size (in bytes as a long field) in
  *                        order to throttle the memory usage.
- * @param serializer serializer used to deserialize the data.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  */
 private[spark]
@@ -53,9 +52,8 @@ final class ShuffleBlockFetcherIterator(
     shuffleClient: ShuffleClient,
     blockManager: BlockManager,
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-    serializer: Serializer,
     maxBytesInFlight: Long)
-  extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
+  extends Iterator[(BlockId, Try[InputStream])] with Logging {
 
   import ShuffleBlockFetcherIterator._
 
@@ -83,7 +81,7 @@ final class ShuffleBlockFetcherIterator(
 
   /**
    * A queue to hold our results. This turns the asynchronous model provided by
-   * [[BlockTransferService]] into a synchronous model (iterator).
+   * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
    */
   private[this] val results = new LinkedBlockingQueue[FetchResult]
 
@@ -102,9 +100,7 @@ final class ShuffleBlockFetcherIterator(
   /** Current bytes in flight from our requests */
   private[this] var bytesInFlight = 0L
 
-  private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-
-  private[this] val serializerInstance: SerializerInstance = serializer.newInstance()
+  private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency()
 
   /**
    * Whether the iterator is still active. If isZombie is true, the callback interface will no
@@ -114,17 +110,23 @@ final class ShuffleBlockFetcherIterator(
 
   initialize()
 
-  /**
-   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
-   */
-  private[this] def cleanup() {
-    isZombie = true
+  // Decrements the buffer reference count.
+  // The currentResult is set to null to prevent releasing the buffer again on cleanup()
+  private[storage] def releaseCurrentResultBuffer(): Unit = {
     // Release the current buffer if necessary
     currentResult match {
       case SuccessFetchResult(_, _, buf) => buf.release()
       case _ =>
     }
+    currentResult = null
+  }
 
+  /**
+   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+   */
+  private[this] def cleanup() {
+    isZombie = true
+    releaseCurrentResultBuffer()
     // Release buffers in the results queue
     val iter = results.iterator()
     while (iter.hasNext) {
@@ -272,7 +274,13 @@ final class ShuffleBlockFetcherIterator(
 
   override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
 
-  override def next(): (BlockId, Try[Iterator[Any]]) = {
+  /**
+   * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers
+   * underlying each InputStream will be freed by the cleanup() method registered with the
+   * TaskCompletionListener. However, callers should close() these InputStreams
+   * as soon as they are no longer needed, in order to release memory as early as possible.
+   */
+  override def next(): (BlockId, Try[InputStream]) = {
     numBlocksProcessed += 1
     val startFetchWait = System.currentTimeMillis()
     currentResult = results.take()
@@ -290,22 +298,15 @@ final class ShuffleBlockFetcherIterator(
       sendRequest(fetchRequests.dequeue())
     }
 
-    val iteratorTry: Try[Iterator[Any]] = result match {
+    val iteratorTry: Try[InputStream] = result match {
       case FailureFetchResult(_, e) =>
         Failure(e)
       case SuccessFetchResult(blockId, _, buf) =>
         // There is a chance that createInputStream can fail (e.g. fetching a local file that does
         // not exist, SPARK-4085). In that case, we should propagate the right exception so
         // the scheduler gets a FetchFailedException.
-        Try(buf.createInputStream()).map { is0 =>
-          val is = blockManager.wrapForCompression(blockId, is0)
-          val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
-          CompletionIterator[Any, Iterator[Any]](iter, {
-            // Once the iterator is exhausted, release the buffer and set currentResult to null
-            // so we don't release it again in cleanup.
-            currentResult = null
-            buf.release()
-          })
+        Try(buf.createInputStream()).map { inputStream =>
+          new BufferReleasingInputStream(inputStream, this)
         }
     }
 
@@ -313,6 +314,39 @@ final class ShuffleBlockFetcherIterator(
   }
 }
 
+/**
+ * Helper class that ensures a ManagedBuffer is released upon InputStream.close()
+ */
+private class BufferReleasingInputStream(
+    private val delegate: InputStream,
+    private val iterator: ShuffleBlockFetcherIterator)
+  extends InputStream {
+  private[this] var closed = false
+
+  override def read(): Int = delegate.read()
+
+  override def close(): Unit = {
+    if (!closed) {
+      delegate.close()
+      iterator.releaseCurrentResultBuffer()
+      closed = true
+    }
+  }
+
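+  // The remaining methods simply delegate to the wrapped input stream.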
+  override def available(): Int = delegate.available()
+
+  override def mark(readlimit: Int): Unit = delegate.mark(readlimit)
+
+  override def skip(n: Long): Long = delegate.skip(n)
+
+  override def markSupported(): Boolean = delegate.markSupported()
+
+  override def read(b: Array[Byte]): Int = delegate.read(b)
+
+  override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)
+
+  override def reset(): Unit = delegate.reset()
+}
 
 private[storage]
 object ShuffleBlockFetcherIterator {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..28ca68698e3dcafaf17479318ca3c81a9ca69ff7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import java.io.{ByteArrayOutputStream, InputStream}
+import java.nio.ByteBuffer
+
+import org.mockito.Matchers.{eq => meq, _}
+import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.apache.spark._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.BaseShuffleHandle
+import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}
+
+/**
+ * Wrapper for a managed buffer that keeps track of how many times retain and release are called.
+ *
+ * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class
+ * is final (final classes cannot be spied on).
+ */
+class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer {
+  var callsToRetain = 0
+  var callsToRelease = 0
+
+  override def size(): Long = underlyingBuffer.size()
+  override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer()
+  override def createInputStream(): InputStream = underlyingBuffer.createInputStream()
+  override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty()
+
+  override def retain(): ManagedBuffer = {
+    callsToRetain += 1
+    underlyingBuffer.retain()
+  }
+  override def release(): ManagedBuffer = {
+    callsToRelease += 1
+    underlyingBuffer.release()
+  }
+}
+
+class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
+
+  /**
+   * This test makes sure that, when data is read from a HashShuffleReader, the underlying
+   * ManagedBuffers that contain the data are eventually released.
+   */
+  test("read() releases resources on completion") {
+    val testConf = new SparkConf(false)
+    // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the
+    // shuffle code calls SparkEnv.get()).
+    sc = new SparkContext("local", "test", testConf)
+
+    val reduceId = 15
+    val shuffleId = 22
+    val numMaps = 6
+    val keyValuePairsPerMap = 10
+    val serializer = new JavaSerializer(testConf)
+
+    // Make a mock BlockManager that will return RecordingManagedBuffers of data, so that we
+    // can ensure retain() and release() are properly called.
+    val blockManager = mock(classOf[BlockManager])
+
+    // Create an Answer for the mocked wrapForCompression method that simply returns the
+    // original input stream unchanged.
+    val dummyCompressionFunction = new Answer[InputStream] {
+      override def answer(invocation: InvocationOnMock): InputStream =
+        invocation.getArguments()(1).asInstanceOf[InputStream]
+    }
+
+    // Create a buffer with some key-value pairs to use as the shuffle data
+    // from each mapper (all mappers return the same shuffle data).
+    val byteOutputStream = new ByteArrayOutputStream()
+    val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
+    (0 until keyValuePairsPerMap).foreach { i =>
+      serializationStream.writeKey(i)
+      serializationStream.writeValue(2*i)
+    }
+
+    // Set up the mocked BlockManager to return RecordingManagedBuffers.
+    val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
+    when(blockManager.blockManagerId).thenReturn(localBlockManagerId)
+    val buffers = (0 until numMaps).map { mapId =>
+      // Create a ManagedBuffer with the shuffle data.
+      val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray))
+      val managedBuffer = new RecordingManagedBuffer(nioBuffer)
+
+      // Set up the blockManager mock so the buffer gets returned when the shuffle code tries to
+      // fetch shuffle data.
+      val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
+      when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer)
+      when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream])))
+        .thenAnswer(dummyCompressionFunction)
+
+      managedBuffer
+    }
+
+    // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
+    // shuffle data to read.
+    val mapOutputTracker = mock(classOf[MapOutputTracker])
+    // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks
+    // for the code to read data over the network.
+    val statuses: Array[(BlockManagerId, Long)] =
+      Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size().toLong))
+    when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses)
+
+    // Create a mocked shuffle handle to pass into HashShuffleReader.
+    val shuffleHandle = {
+      val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
+      when(dependency.serializer).thenReturn(Some(serializer))
+      when(dependency.aggregator).thenReturn(None)
+      when(dependency.keyOrdering).thenReturn(None)
+      new BaseShuffleHandle(shuffleId, numMaps, dependency)
+    }
+
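+    // Pass in the mocked BlockManager and MapOutputTracker through the reader's new
+    // constructor parameters (which default to SparkEnv in production code).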
+    val shuffleReader = new HashShuffleReader(
+      shuffleHandle,
+      reduceId,
+      reduceId + 1,
+      new TaskContextImpl(0, 0, 0, 0, null),
+      blockManager,
+      mapOutputTracker)
+
+    assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)
+
+    // Calling .length above will have exhausted the iterator; make sure that exhausting the
+    // iterator caused retain and release to be called on each buffer.
+    buffers.foreach { buffer =>
+      assert(buffer.callsToRetain === 1)
+      assert(buffer.callsToRelease === 1)
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 2a7fe67ad85855bdddbce6122efb3cc1dbac506e..9ced4148d7206e6dd94fcb74ac8101fb4f5a18ef 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -17,23 +17,25 @@
 
 package org.apache.spark.storage
 
+import java.io.InputStream
 import java.util.concurrent.Semaphore
 
-import scala.concurrent.future
 import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent.future
 
 import org.mockito.Matchers.{any, eq => meq}
 import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{SparkConf, SparkFunSuite, TaskContextImpl}
+import org.apache.spark.{SparkFunSuite, TaskContextImpl}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.shuffle.BlockFetchingListener
-import org.apache.spark.serializer.TestSerializer
 
-class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
+
+class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
   // Some of the tests are quite tricky because we are testing the cleanup behavior
   // in the presence of faults.
 
@@ -57,7 +59,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
     transfer
   }
 
-  private val conf = new SparkConf
+  // Create a mock managed buffer for testing
+  def createMockManagedBuffer(): ManagedBuffer = {
+    val mockManagedBuffer = mock(classOf[ManagedBuffer])
+    when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream]))
+    mockManagedBuffer
+  }
 
   test("successful 3 local reads + 2 remote reads") {
     val blockManager = mock(classOf[BlockManager])
@@ -66,9 +73,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
 
     // Make sure blockManager.getBlockData would return the blocks
     val localBlocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]))
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
     localBlocks.foreach { case (blockId, buf) =>
       doReturn(buf).when(blockManager).getBlockData(meq(blockId))
     }
@@ -76,9 +83,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val remoteBlocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer])
-    )
+      ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer())
 
     val transfer = createMockTransfer(remoteBlocks)
 
@@ -92,7 +98,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
       transfer,
       blockManager,
       blocksByAddress,
-      new TestSerializer,
       48 * 1024 * 1024)
 
     // 3 local blocks fetched in initialization
@@ -100,15 +105,24 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
 
     for (i <- 0 until 5) {
       assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
-      val (blockId, subIterator) = iterator.next()
-      assert(subIterator.isSuccess,
+      val (blockId, inputStream) = iterator.next()
+      assert(inputStream.isSuccess,
         s"iterator should have 5 elements defined but actually has $i elements")
 
-      // Make sure we release the buffer once the iterator is exhausted.
+      // Make sure we release buffers when a wrapped input stream is closed.
       val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
+      // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream
+      val wrappedInputStream = inputStream.get.asInstanceOf[BufferReleasingInputStream]
       verify(mockBuf, times(0)).release()
-      subIterator.get.foreach(_ => Unit)  // exhaust the iterator
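+      // Use ScalaTest's PrivateMethodTester to access the wrapper's private delegate stream and
+      // verify that closing the wrapper closes the delegate exactly once and releases the buffer.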
+      val delegateAccess = PrivateMethod[InputStream]('delegate)
+
+      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close()
+      wrappedInputStream.close()
+      verify(mockBuf, times(1)).release()
+      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
+      wrappedInputStream.close() // close should be idempotent
       verify(mockBuf, times(1)).release()
+      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
     }
 
     // 3 local blocks, and 2 remote blocks
@@ -125,10 +139,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val blocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
-    )
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
 
     // Semaphore to coordinate event sequence in two different threads.
     val sem = new Semaphore(0)
@@ -159,11 +172,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
       transfer,
       blockManager,
       blocksByAddress,
-      new TestSerializer,
       48 * 1024 * 1024)
 
-    // Exhaust the first block, and then it should be released.
-    iterator.next()._2.get.foreach(_ => Unit)
+    verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
+    iterator.next()._2.get.close() // close() first block's input stream
     verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release()
 
     // Get the 2nd block but do not exhaust the iterator
@@ -222,7 +234,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
       transfer,
       blockManager,
       blocksByAddress,
-      new TestSerializer,
       48 * 1024 * 1024)
 
     // Continue only after the mock calls onBlockFetchFailure