diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 35c4dafe9c19ccae485de8fe27f65e8acda0ef87..1ed36bf0692f811aefc2e1ae91c05d37c31c04ff 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -230,6 +230,7 @@ private[spark] object Task {
     dataOut.flush()
     val taskBytes = serializer.serialize(task)
     Utils.writeByteBuffer(taskBytes, out)
+    out.close()
     out.toByteBuffer
   }
 
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index ec1b0f71492719dd8141b747039e9fbb60d24d6a..205d469f481441b3cd59289bf4a4bee17e3c4e0e 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -33,7 +33,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode}
 import org.apache.spark.serializer.{SerializationStream, SerializerManager}
 import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
 import org.apache.spark.unsafe.Platform
-import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
+import org.apache.spark.util.{SizeEstimator, Utils}
 import org.apache.spark.util.collection.SizeTrackingVector
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
 
@@ -277,6 +277,7 @@ private[spark] class MemoryStore(
           "released too much unroll memory")
         Left(new PartiallyUnrolledIterator(
           this,
+          MemoryMode.ON_HEAP,
           unrollMemoryUsedByThisBlock,
           unrolled = arrayValues.toIterator,
           rest = Iterator.empty))
@@ -285,7 +286,11 @@ private[spark] class MemoryStore(
       // We ran out of space while unrolling the values for this block
       logUnrollFailureMessage(blockId, vector.estimateSize())
       Left(new PartiallyUnrolledIterator(
-        this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values))
+        this,
+        MemoryMode.ON_HEAP,
+        unrollMemoryUsedByThisBlock,
+        unrolled = vector.iterator,
+        rest = values))
     }
   }
 
@@ -394,7 +399,7 @@ private[spark] class MemoryStore(
           redirectableStream,
           unrollMemoryUsedByThisBlock,
           memoryMode,
-          bbos.toChunkedByteBuffer,
+          bbos,
           values,
           classTag))
     }
@@ -655,6 +660,7 @@ private[spark] class MemoryStore(
  * The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
  *
  * @param memoryStore  the memoryStore, used for freeing memory.
+ * @param memoryMode   the memory mode (on- or off-heap).
  * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
  * @param unrolled     an iterator for the partially-unrolled values.
  * @param rest         the rest of the original iterator passed to
@@ -662,13 +668,14 @@ private[spark] class MemoryStore(
  */
 private[storage] class PartiallyUnrolledIterator[T](
     memoryStore: MemoryStore,
+    memoryMode: MemoryMode,
     unrollMemory: Long,
     private[this] var unrolled: Iterator[T],
     rest: Iterator[T])
   extends Iterator[T] {
 
   private def releaseUnrollMemory(): Unit = {
-    memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
+    memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
     // SPARK-17503: Garbage collects the unrolling memory before the life end of
     // PartiallyUnrolledIterator.
     unrolled = null
@@ -706,7 +713,7 @@ private[storage] class PartiallyUnrolledIterator[T](
 /**
  * A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
  */
-private class RedirectableOutputStream extends OutputStream {
+private[storage] class RedirectableOutputStream extends OutputStream {
   private[this] var os: OutputStream = _
   def setOutputStream(s: OutputStream): Unit = { os = s }
   override def write(b: Int): Unit = os.write(b)
@@ -726,7 +733,8 @@ private class RedirectableOutputStream extends OutputStream {
  * @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
  * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
  * @param memoryMode whether the unroll memory is on- or off-heap
- * @param unrolled a byte buffer containing the partially-serialized values.
+ * @param bbos byte buffer output stream containing the partially-serialized values.
+ *                     [[redirectableOutputStream]] initially points to this output stream.
  * @param rest         the rest of the original iterator passed to
  *                     [[MemoryStore.putIteratorAsValues()]].
  * @param classTag the [[ClassTag]] for the block.
@@ -735,14 +743,19 @@ private[storage] class PartiallySerializedBlock[T](
     memoryStore: MemoryStore,
     serializerManager: SerializerManager,
     blockId: BlockId,
-    serializationStream: SerializationStream,
-    redirectableOutputStream: RedirectableOutputStream,
-    unrollMemory: Long,
+    private val serializationStream: SerializationStream,
+    private val redirectableOutputStream: RedirectableOutputStream,
+    val unrollMemory: Long,
     memoryMode: MemoryMode,
-    unrolled: ChunkedByteBuffer,
+    bbos: ChunkedByteBufferOutputStream,
     rest: Iterator[T],
     classTag: ClassTag[T]) {
 
+  private lazy val unrolledBuffer: ChunkedByteBuffer = {
+    bbos.close()
+    bbos.toChunkedByteBuffer
+  }
+
   // If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of
   // this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task
   // completion listener here in order to ensure that `unrolled.dispose()` is called at least once.
@@ -751,7 +764,23 @@ private[storage] class PartiallySerializedBlock[T](
     taskContext.addTaskCompletionListener { _ =>
       // When a task completes, its unroll memory will automatically be freed. Thus we do not call
       // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing.
-      unrolled.dispose()
+      unrolledBuffer.dispose()
+    }
+  }
+
+  // Exposed for testing
+  private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer
+
+  private[this] var discarded = false
+  private[this] var consumed = false
+
+  private def verifyNotConsumedAndNotDiscarded(): Unit = {
+    if (consumed) {
+      throw new IllegalStateException(
+        "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.")
+    }
+    if (discarded) {
+      throw new IllegalStateException("Cannot call methods on a discarded PartiallySerializedBlock")
     }
   }
 
@@ -759,15 +788,18 @@ private[storage] class PartiallySerializedBlock[T](
    * Called to dispose of this block and free its memory.
    */
   def discard(): Unit = {
-    try {
-      // We want to close the output stream in order to free any resources associated with the
-      // serializer itself (such as Kryo's internal buffers). close() might cause data to be
-      // written, so redirect the output stream to discard that data.
-      redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
-      serializationStream.close()
-    } finally {
-      unrolled.dispose()
-      memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+    if (!discarded) {
+      try {
+        // We want to close the output stream in order to free any resources associated with the
+        // serializer itself (such as Kryo's internal buffers). close() might cause data to be
+        // written, so redirect the output stream to discard that data.
+        redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
+        serializationStream.close()
+      } finally {
+        discarded = true
+        unrolledBuffer.dispose()
+        memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+      }
     }
   }
 
@@ -776,8 +808,10 @@ private[storage] class PartiallySerializedBlock[T](
    * and then serializing the values from the original input iterator.
    */
   def finishWritingToStream(os: OutputStream): Unit = {
+    verifyNotConsumedAndNotDiscarded()
+    consumed = true
     // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
-    ByteStreams.copy(unrolled.toInputStream(dispose = true), os)
+    ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os)
     memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
     redirectableOutputStream.setOutputStream(os)
     while (rest.hasNext) {
@@ -794,13 +828,22 @@ private[storage] class PartiallySerializedBlock[T](
    * `close()` on it to free its resources.
    */
   def valuesIterator: PartiallyUnrolledIterator[T] = {
+    verifyNotConsumedAndNotDiscarded()
+    consumed = true
+    // Close the serialization stream so that the serializer's internal buffers are freed and any
+    // "end-of-stream" markers can be written out so that `unrolled` is a valid serialized stream.
+    serializationStream.close()
     // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
     val unrolledIter = serializerManager.dataDeserializeStream(
-      blockId, unrolled.toInputStream(dispose = true))(classTag)
+      blockId, unrolledBuffer.toInputStream(dispose = true))(classTag)
+    // The unroll memory will be freed once `unrolledIter` is fully consumed in
+    // PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any
+    // extra unroll memory will automatically be freed by a `finally` block in `Task`.
     new PartiallyUnrolledIterator(
       memoryStore,
+      memoryMode,
       unrollMemory,
-      unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()),
+      unrolled = unrolledIter,
       rest = rest)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
index 09e7579ae96060e6389fd98ab2011187ff039a77..9077b86f9ba1dfede6eec1ffb67fee4fa88c4b4c 100644
--- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
@@ -29,7 +29,32 @@ private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutp
 
   def getCount(): Int = count
 
+  private[this] var closed: Boolean = false
+
+  override def write(b: Int): Unit = {
+    require(!closed, "cannot write to a closed ByteBufferOutputStream")
+    super.write(b)
+  }
+
+  override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+    require(!closed, "cannot write to a closed ByteBufferOutputStream")
+    super.write(b, off, len)
+  }
+
+  override def reset(): Unit = {
+    require(!closed, "cannot reset a closed ByteBufferOutputStream")
+    super.reset()
+  }
+
+  override def close(): Unit = {
+    if (!closed) {
+      super.close()
+      closed = true
+    }
+  }
+
   def toByteBuffer: ByteBuffer = {
-    return ByteBuffer.wrap(buf, 0, count)
+    require(closed, "can only call toByteBuffer() after ByteBufferOutputStream has been closed")
+    ByteBuffer.wrap(buf, 0, count)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
index 67b50d1e70437ca611e81f2fd6984d466b378289..a625b3289538a4760e412ad98554b15035295bde 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
@@ -49,10 +49,19 @@ private[spark] class ChunkedByteBufferOutputStream(
    */
   private[this] var position = chunkSize
   private[this] var _size = 0
+  private[this] var closed: Boolean = false
 
   def size: Long = _size
 
+  override def close(): Unit = {
+    if (!closed) {
+      super.close()
+      closed = true
+    }
+  }
+
   override def write(b: Int): Unit = {
+    require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
     allocateNewChunkIfNeeded()
     chunks(lastChunkIndex).put(b.toByte)
     position += 1
@@ -60,6 +69,7 @@ private[spark] class ChunkedByteBufferOutputStream(
   }
 
   override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
+    require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
     var written = 0
     while (written < len) {
       allocateNewChunkIfNeeded()
@@ -73,7 +83,6 @@ private[spark] class ChunkedByteBufferOutputStream(
 
   @inline
   private def allocateNewChunkIfNeeded(): Unit = {
-    require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called")
     if (position == chunkSize) {
       chunks += allocator(chunkSize)
       lastChunkIndex += 1
@@ -82,6 +91,7 @@ private[spark] class ChunkedByteBufferOutputStream(
   }
 
   def toChunkedByteBuffer: ChunkedByteBuffer = {
+    require(closed, "cannot call toChunkedByteBuffer() unless close() has been called")
     require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once")
     toChunkedByteBufferWasCalled = true
     if (lastChunkIndex == -1) {
diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
index c11de826677e089502e30cb64cad25bc561bf3cf..9929ea033a99f3996cb8bb057d0c15b71184cc27 100644
--- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
@@ -79,6 +79,13 @@ class MemoryStoreSuite
     (memoryStore, blockInfoManager)
   }
 
+  private def assertSameContents[T](expected: Seq[T], actual: Seq[T], hint: String): Unit = {
+    assert(actual.length === expected.length, s"wrong number of values returned in $hint")
+    expected.iterator.zip(actual.iterator).foreach { case (e, a) =>
+      assert(e === a, s"$hint did not return original values!")
+    }
+  }
+
   test("reserve/release unroll memory") {
     val (memoryStore, _) = makeMemoryStore(12000)
     assert(memoryStore.currentUnrollMemory === 0)
@@ -137,9 +144,7 @@ class MemoryStoreSuite
     var putResult = putIteratorAsValues("unroll", smallList.iterator, ClassTag.Any)
     assert(putResult.isRight)
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-    smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
-      assert(e === a, "getValues() did not return original values!")
-    }
+    assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
     blockInfoManager.lockForWriting("unroll")
     assert(memoryStore.remove("unroll"))
     blockInfoManager.removeBlock("unroll")
@@ -152,9 +157,7 @@ class MemoryStoreSuite
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
     assert(memoryStore.contains("someBlock2"))
     assert(!memoryStore.contains("someBlock1"))
-    smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
-      assert(e === a, "getValues() did not return original values!")
-    }
+    assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
     blockInfoManager.lockForWriting("unroll")
     assert(memoryStore.remove("unroll"))
     blockInfoManager.removeBlock("unroll")
@@ -167,9 +170,7 @@ class MemoryStoreSuite
     assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
     assert(!memoryStore.contains("someBlock2"))
     assert(putResult.isLeft)
-    bigList.iterator.zip(putResult.left.get).foreach { case (e, a) =>
-      assert(e === a, "putIterator() did not return original values!")
-    }
+    assertSameContents(bigList, putResult.left.get.toSeq, "putIterator")
     // The unroll memory was freed once the iterator returned by putIterator() was fully traversed.
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
   }
@@ -316,9 +317,8 @@ class MemoryStoreSuite
     assert(res.isLeft)
     assert(memoryStore.currentUnrollMemoryForThisTask > 0)
     val valuesReturnedFromFailedPut = res.left.get.valuesIterator.toSeq // force materialization
-    valuesReturnedFromFailedPut.zip(bigList).foreach { case (e, a) =>
-      assert(e === a, "PartiallySerializedBlock.valuesIterator() did not return original values!")
-    }
+    assertSameContents(
+      bigList, valuesReturnedFromFailedPut, "PartiallySerializedBlock.valuesIterator()")
     // The unroll memory was freed once the iterator was fully traversed.
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
   }
@@ -340,12 +340,10 @@ class MemoryStoreSuite
     res.left.get.finishWritingToStream(bos)
     // The unroll memory was freed once the block was fully written.
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-    val deserializationStream = serializerManager.dataDeserializeStream[Any](
-      "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any)
-    deserializationStream.zip(bigList.iterator).foreach { case (e, a) =>
-      assert(e === a,
-        "PartiallySerializedBlock.finishWritingtoStream() did not write original values!")
-    }
+    val deserializedValues = serializerManager.dataDeserializeStream[Any](
+      "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any).toSeq
+    assertSameContents(
+      bigList, deserializedValues, "PartiallySerializedBlock.finishWritingToStream()")
   }
 
   test("multiple unrolls by the same thread") {
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ec4f2637fadd00f28a089dadb2758294471569cf
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+import org.mockito.Mockito
+import org.mockito.Mockito.atLeastOnce
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
+
+import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
+import org.apache.spark.memory.MemoryMode
+import org.apache.spark.serializer.{JavaSerializer, SerializationStream, SerializerManager}
+import org.apache.spark.storage.memory.{MemoryStore, PartiallySerializedBlock, RedirectableOutputStream}
+import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream}
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
+
+class PartiallySerializedBlockSuite
+    extends SparkFunSuite
+    with BeforeAndAfterEach
+    with PrivateMethodTester {
+
+  private val blockId = new TestBlockId("test")
+  private val conf = new SparkConf()
+  private val memoryStore = Mockito.mock(classOf[MemoryStore], Mockito.RETURNS_SMART_NULLS)
+  private val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
+
+  private val getSerializationStream = PrivateMethod[SerializationStream]('serializationStream)
+  private val getRedirectableOutputStream =
+    PrivateMethod[RedirectableOutputStream]('redirectableOutputStream)
+
+  override protected def beforeEach(): Unit = {
+    super.beforeEach()
+    Mockito.reset(memoryStore)
+  }
+
+  private def partiallyUnroll[T: ClassTag](
+      iter: Iterator[T],
+      numItemsToBuffer: Int): PartiallySerializedBlock[T] = {
+
+    val bbos: ChunkedByteBufferOutputStream = {
+      val spy = Mockito.spy(new ChunkedByteBufferOutputStream(128, ByteBuffer.allocate))
+      Mockito.doAnswer(new Answer[ChunkedByteBuffer] {
+        override def answer(invocationOnMock: InvocationOnMock): ChunkedByteBuffer = {
+          Mockito.spy(invocationOnMock.callRealMethod().asInstanceOf[ChunkedByteBuffer])
+        }
+      }).when(spy).toChunkedByteBuffer
+      spy
+    }
+
+    val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
+    val redirectableOutputStream = Mockito.spy(new RedirectableOutputStream)
+    redirectableOutputStream.setOutputStream(bbos)
+    val serializationStream = Mockito.spy(serializer.serializeStream(redirectableOutputStream))
+
+    (1 to numItemsToBuffer).foreach { _ =>
+      assert(iter.hasNext)
+      serializationStream.writeObject[T](iter.next())
+    }
+
+    val unrollMemory = bbos.size
+    new PartiallySerializedBlock[T](
+      memoryStore,
+      serializerManager,
+      blockId,
+      serializationStream = serializationStream,
+      redirectableOutputStream,
+      unrollMemory = unrollMemory,
+      memoryMode = MemoryMode.ON_HEAP,
+      bbos,
+      rest = iter,
+      classTag = implicitly[ClassTag[T]])
+  }
+
+  test("valuesIterator() and finishWritingToStream() cannot be called after discard() is called") {
+    val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+    partiallySerializedBlock.discard()
+    intercept[IllegalStateException] {
+      partiallySerializedBlock.finishWritingToStream(null)
+    }
+    intercept[IllegalStateException] {
+      partiallySerializedBlock.valuesIterator
+    }
+  }
+
+  test("discard() can be called more than once") {
+    val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+    partiallySerializedBlock.discard()
+    partiallySerializedBlock.discard()
+  }
+
+  test("cannot call valuesIterator() more than once") {
+    val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+    partiallySerializedBlock.valuesIterator
+    intercept[IllegalStateException] {
+      partiallySerializedBlock.valuesIterator
+    }
+  }
+
+  test("cannot call finishWritingToStream() more than once") {
+    val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+    partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+    intercept[IllegalStateException] {
+      partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+    }
+  }
+
+  test("cannot call finishWritingToStream() after valuesIterator()") {
+    val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+    partiallySerializedBlock.valuesIterator
+    intercept[IllegalStateException] {
+      partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+    }
+  }
+
+  test("cannot call valuesIterator() after finishWritingToStream()") {
+    val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+    partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+    intercept[IllegalStateException] {
+      partiallySerializedBlock.valuesIterator
+    }
+  }
+
+  test("buffers are deallocated in a TaskCompletionListener") {
+    try {
+      TaskContext.setTaskContext(TaskContext.empty())
+      val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+      TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted()
+      Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose()
+      Mockito.verifyNoMoreInteractions(memoryStore)
+    } finally {
+      TaskContext.unset()
+    }
+  }
+
+  private def testUnroll[T: ClassTag](
+      testCaseName: String,
+      items: Seq[T],
+      numItemsToBuffer: Int): Unit = {
+
+    test(s"$testCaseName with discard() and numBuffered = $numItemsToBuffer") {
+      val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer)
+      partiallySerializedBlock.discard()
+
+      Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+        MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+      Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+      Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+      Mockito.verifyNoMoreInteractions(memoryStore)
+      Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+    }
+
+    test(s"$testCaseName with finishWritingToStream() and numBuffered = $numItemsToBuffer") {
+      val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer)
+      val bbos = Mockito.spy(new ByteBufferOutputStream())
+      partiallySerializedBlock.finishWritingToStream(bbos)
+
+      Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+        MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+      Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+      Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+      Mockito.verify(bbos).close()
+      Mockito.verifyNoMoreInteractions(memoryStore)
+      Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+
+      val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
+      val deserialized =
+        serializer.deserializeStream(new ByteBufferInputStream(bbos.toByteBuffer)).asIterator.toSeq
+      assert(deserialized === items)
+    }
+
+    test(s"$testCaseName with valuesIterator() and numBuffered = $numItemsToBuffer") {
+      val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer)
+      val valuesIterator = partiallySerializedBlock.valuesIterator
+      Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+      Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+
+      val deserializedItems = valuesIterator.toArray.toSeq
+      Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+        MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+      Mockito.verifyNoMoreInteractions(memoryStore)
+      Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+      assert(deserializedItems === items)
+    }
+  }
+
+  testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 50)
+  testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 0)
+  testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 1000)
+  testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 50)
+  testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 0)
+  testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 1000)
+  testUnroll("empty iterator", Seq.empty[String], numItemsToBuffer = 0)
+}
+
+private case class MyCaseClass(str: String)
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
index 02c2331dc3946273903ea6cfa19ef89a7f7694bd..4253cc8ca4cd1f511f6470e107c1ccdfa02f544a 100644
--- a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
@@ -33,7 +33,7 @@ class PartiallyUnrolledIteratorSuite extends SparkFunSuite with MockitoSugar {
     val rest = (unrollSize until restSize + unrollSize).iterator
 
     val memoryStore = mock[MemoryStore]
-    val joinIterator = new PartiallyUnrolledIterator(memoryStore, unrollSize, unroll, rest)
+    val joinIterator = new PartiallyUnrolledIterator(memoryStore, ON_HEAP, unrollSize, unroll, rest)
 
     // Firstly iterate over unrolling memory iterator
     (0 until unrollSize).foreach { value =>
diff --git a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
index 226622075a6cc79006e9a6d909a1f1929c0b0586..86961745673c6bb95e610c528f101e5872471851 100644
--- a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
@@ -28,12 +28,14 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
 
   test("empty output") {
     val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate)
+    o.close()
     assert(o.toChunkedByteBuffer.size === 0)
   }
 
   test("write a single byte") {
     val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate)
     o.write(10)
+    o.close()
     val chunkedByteBuffer = o.toChunkedByteBuffer
     assert(chunkedByteBuffer.getChunks().length === 1)
     assert(chunkedByteBuffer.getChunks().head.array().toSeq === Seq(10.toByte))
@@ -43,6 +45,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
     val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
     o.write(new Array[Byte](9))
     o.write(99)
+    o.close()
     val chunkedByteBuffer = o.toChunkedByteBuffer
     assert(chunkedByteBuffer.getChunks().length === 1)
     assert(chunkedByteBuffer.getChunks().head.array()(9) === 99.toByte)
@@ -52,6 +55,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
     val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
     o.write(new Array[Byte](10))
     o.write(99)
+    o.close()
     val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
     assert(arrays.length === 2)
     assert(arrays(1).length === 1)
@@ -63,6 +67,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
     Random.nextBytes(ref)
     val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
     o.write(ref)
+    o.close()
     val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
     assert(arrays.length === 1)
     assert(arrays.head.length === ref.length)
@@ -74,6 +79,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
     Random.nextBytes(ref)
     val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
     o.write(ref)
+    o.close()
     val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
     assert(arrays.length === 1)
     assert(arrays.head.length === ref.length)
@@ -85,6 +91,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
     Random.nextBytes(ref)
     val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
     o.write(ref)
+    o.close()
     val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
     assert(arrays.length === 3)
     assert(arrays(0).length === 10)
@@ -101,6 +108,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
     Random.nextBytes(ref)
     val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
     o.write(ref)
+    o.close()
     val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
     assert(arrays.length === 3)
     assert(arrays(0).length === 10)