diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index d220ab51d115bfcbc24f2deb03bff153a75630a9..1a3bf2bb672c6a0d180031b4846adb6bad32c539 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -663,31 +663,43 @@ private[spark] class MemoryStore(
 private[storage] class PartiallyUnrolledIterator[T](
     memoryStore: MemoryStore,
     unrollMemory: Long,
-    unrolled: Iterator[T],
+    private[this] var unrolled: Iterator[T],
     rest: Iterator[T])
   extends Iterator[T] {
 
-  private[this] var unrolledIteratorIsConsumed: Boolean = false
-  private[this] var iter: Iterator[T] = {
-    val completionIterator = CompletionIterator[T, Iterator[T]](unrolled, {
-      unrolledIteratorIsConsumed = true
-      memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
-    })
-    completionIterator ++ rest
+  private def releaseUnrollMemory(): Unit = {
+    memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
+    // SPARK-17503: Garbage collects the unrolling memory before the life end of
+    // PartiallyUnrolledIterator.
+    unrolled = null
   }
 
-  override def hasNext: Boolean = iter.hasNext
-  override def next(): T = iter.next()
+  override def hasNext: Boolean = {
+    if (unrolled == null) {
+      rest.hasNext
+    } else if (!unrolled.hasNext) {
+      releaseUnrollMemory()
+      rest.hasNext
+    } else {
+      true
+    }
+  }
+
+  override def next(): T = {
+    if (unrolled == null) {
+      rest.next()
+    } else {
+      unrolled.next()
+    }
+  }
 
   /**
    * Called to dispose of this iterator and free its memory.
    */
   def close(): Unit = {
-    if (!unrolledIteratorIsConsumed) {
-      memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
-      unrolledIteratorIsConsumed = true
+    if (unrolled != null) {
+      releaseUnrollMemory()
     }
-    iter = null
   }
 }
 
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..02c2331dc3946273903ea6cfa19ef89a7f7694bd
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.mockito.Matchers
+import org.mockito.Mockito._
+import org.scalatest.mock.MockitoSugar
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.memory.MemoryMode.ON_HEAP
+import org.apache.spark.storage.memory.{MemoryStore, PartiallyUnrolledIterator}
+
+class PartiallyUnrolledIteratorSuite extends SparkFunSuite with MockitoSugar {
+  test("join two iterators") {
+    val unrollSize = 1000
+    val unroll = (0 until unrollSize).iterator
+    val restSize = 500
+    val rest = (unrollSize until restSize + unrollSize).iterator
+
+    val memoryStore = mock[MemoryStore]
+    val joinIterator = new PartiallyUnrolledIterator(memoryStore, unrollSize, unroll, rest)
+
+    // Firstly iterate over unrolling memory iterator
+    (0 until unrollSize).foreach { value =>
+      assert(joinIterator.hasNext)
+      assert(joinIterator.hasNext)
+      assert(joinIterator.next() == value)
+    }
+
+    joinIterator.hasNext
+    joinIterator.hasNext
+    verify(memoryStore, times(1))
+      .releaseUnrollMemoryForThisTask(Matchers.eq(ON_HEAP), Matchers.eq(unrollSize.toLong))
+
+    // Secondly, iterate over rest iterator
+    (unrollSize until unrollSize + restSize).foreach { value =>
+      assert(joinIterator.hasNext)
+      assert(joinIterator.hasNext)
+      assert(joinIterator.next() == value)
+    }
+
+    joinIterator.close()
+    // MemoryMode.releaseUnrollMemoryForThisTask is called only once
+    verifyNoMoreInteractions(memoryStore)
+  }
+}