diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index 8f83668d7902989ec17e86c90b292cb79f7c11e3..b3f8bfe8b1d48627c02a53075a5c5b0ff62df7a7 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -46,5 +46,7 @@ trait BlockDataManager {
  /**
   * Release locks acquired by [[putBlockData()]] and [[getBlockData()]].
+   * If `taskAttemptId` is specified, the lock is released on behalf of that task attempt
+   * rather than the task (if any) bound to the calling thread.
   */
-  def releaseLock(blockId: BlockId): Unit
+  def releaseLock(blockId: BlockId, taskAttemptId: Option[Long]): Unit
 }
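
The trait change above forces every `BlockDataManager` caller to decide whose lock is being released. As a hedged sketch (the `releaseFor` helper and its parameters are hypothetical, not part of this patch), a caller that tracked the owning task's attempt id itself would pass it explicitly:

```scala
import org.apache.spark.network.BlockDataManager
import org.apache.spark.storage.BlockId

// Hypothetical helper: release a lock on behalf of the task attempt that
// acquired it, rather than whatever task (if any) owns the current thread.
def releaseFor(manager: BlockDataManager, blockId: BlockId, ownerTid: Long): Unit = {
  manager.releaseLock(blockId, Some(ownerTid))
}
```

Callers with no better information can pass `None`, which falls back to the old TaskContext-based lookup.
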
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
index 3db59837fbebd4e0c1daeb52e8f95f4ef9a6417f..7064872ec1c7751d45c679f293d6b5ceaacc4329 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
@@ -281,22 +281,27 @@ private[storage] class BlockInfoManager extends Logging {
 
   /**
    * Release a lock on the given block.
+   * If a task's TaskContext is not properly propagated to all of its child threads, we cannot
+   * read the TID from TaskContext there, so callers must pass the TID explicitly instead.
+   *
+   * See SPARK-18406 for more discussion of this issue.
    */
-  def unlock(blockId: BlockId): Unit = synchronized {
-    logTrace(s"Task $currentTaskAttemptId releasing lock for $blockId")
+  def unlock(blockId: BlockId, taskAttemptId: Option[TaskAttemptId] = None): Unit = synchronized {
+    val taskId = taskAttemptId.getOrElse(currentTaskAttemptId)
+    logTrace(s"Task $taskId releasing lock for $blockId")
     val info = get(blockId).getOrElse {
       throw new IllegalStateException(s"Block $blockId not found")
     }
     if (info.writerTask != BlockInfo.NO_WRITER) {
       info.writerTask = BlockInfo.NO_WRITER
-      writeLocksByTask.removeBinding(currentTaskAttemptId, blockId)
+      writeLocksByTask.removeBinding(taskId, blockId)
     } else {
       assert(info.readerCount > 0, s"Block $blockId is not locked for reading")
       info.readerCount -= 1
-      val countsForTask = readLocksByTask(currentTaskAttemptId)
+      val countsForTask = readLocksByTask(taskId)
       val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1
       assert(newPinCountForTask >= 0,
-        s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it")
+        s"Task $taskId release lock on block $blockId more times than it acquired it")
     }
     notifyAll()
   }
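
The reason the fallback fails is that `TaskContext` is held in a thread-local that is not inherited by threads the task spawns. A minimal sketch of the failure mode, assuming it runs inside a task body (the thread wiring is illustrative only):

```scala
import org.apache.spark.TaskContext

// On the task's main thread the context is set, so the TID is available...
val tid = TaskContext.get().taskAttemptId()

val child = new Thread(new Runnable {
  override def run(): Unit = {
    // ...but a plain child thread never had the thread-local set, so
    // TaskContext.get() returns null here. Without the explicit TID,
    // unlock() would fall back to a non-task placeholder id and the
    // per-task lock bookkeeping above would not match.
    assert(TaskContext.get() == null)
  }
})
child.start()
child.join()
```
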
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 137d24b525155cb64413ee8812c5a258381aa884..1689baa832d52469d7dc8173c040e0da30185719 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -501,6 +501,7 @@ private[spark] class BlockManager(
       case Some(info) =>
         val level = info.level
         logDebug(s"Level for block $blockId is $level")
+        val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId())
         if (level.useMemory && memoryStore.contains(blockId)) {
           val iter: Iterator[Any] = if (level.deserialized) {
             memoryStore.getValues(blockId).get
@@ -508,7 +509,12 @@ private[spark] class BlockManager(
             serializerManager.dataDeserializeStream(
               blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag)
           }
-          val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
+          // We need to capture the current taskAttemptId here, because the iterator's
+          // completion may be triggered from a different thread that does not have the
+          // TaskContext set; see SPARK-18406 for discussion.
+          val ci = CompletionIterator[Any, Iterator[Any]](iter, {
+            releaseLock(blockId, taskAttemptId)
+          })
           Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
         } else if (level.useDisk && diskStore.contains(blockId)) {
           val diskData = diskStore.getBytes(blockId)
@@ -525,8 +531,9 @@ private[spark] class BlockManager(
               serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
             }
           }
-          val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn,
-            releaseLockAndDispose(blockId, diskData))
+          val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, {
+            releaseLockAndDispose(blockId, diskData, taskAttemptId)
+          })
           Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
         } else {
           handleLocalReadFailure(blockId)
@@ -711,10 +718,13 @@ private[spark] class BlockManager(
   }
 
   /**
-   * Release a lock on the given block.
+   * Release a lock on the given block, optionally on behalf of an explicit task attempt.
+   * The `taskAttemptId` parameter should be passed when the correct TID cannot be obtained
+   * from TaskContext, for example when the input iterator of a cached RDD is consumed to the
+   * end in a child thread.
    */
-  def releaseLock(blockId: BlockId): Unit = {
-    blockInfoManager.unlock(blockId)
+  def releaseLock(blockId: BlockId, taskAttemptId: Option[Long] = None): Unit = {
+    blockInfoManager.unlock(blockId, taskAttemptId)
   }
 
   /**
@@ -1467,8 +1477,11 @@ private[spark] class BlockManager(
     }
   }
 
-  def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = {
-    blockInfoManager.unlock(blockId)
+  def releaseLockAndDispose(
+      blockId: BlockId,
+      data: BlockData,
+      taskAttemptId: Option[Long] = None): Unit = {
+    releaseLock(blockId, taskAttemptId)
     data.dispose()
   }
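
The pattern used in `getLocalValues` above generalizes: read the TID eagerly while still on the task's thread, and let the completion callback close over the captured value. A condensed sketch of that pattern (`release` stands in for `releaseLock`/`releaseLockAndDispose`; since `CompletionIterator` is `private[spark]`, this assumes code living under `org.apache.spark`):

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.CompletionIterator

def lockReleasingIterator[T](
    iter: Iterator[T],
    release: Option[Long] => Unit): Iterator[T] = {
  // Evaluated now, on the task's main thread, where the TaskContext is set.
  val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId())
  CompletionIterator[T, Iterator[T]](iter, {
    // May run on a child thread with no TaskContext; the captured TID keeps
    // the lock attributed to the task that actually acquired it.
    release(taskAttemptId)
  })
}
```
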
 
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ad56715656c85d1e3e5fa39b477becec524549a1..8d06f5468f4f1e50e6ee5ab7a9e7a93109463743 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -30,7 +30,7 @@ import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.rdd.RDDSuiteUtils._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 class RDDSuite extends SparkFunSuite with SharedSparkContext {
   var tempDir: File = _
@@ -1082,6 +1082,22 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     assert(totalPartitionCount == 10)
   }
 
+  test("SPARK-18406: race between end-of-task and completion iterator read lock release") {
+    val rdd = sc.parallelize(1 to 1000, 10)
+    rdd.cache()
+
+    rdd.mapPartitions { iter =>
+      ThreadUtils.runInNewThread("TestThread") {
+        // Consume the input iterator to the end, so that the CompletionIterator's
+        // completion callback fires outside of the task's main thread.
+        while (iter.hasNext) {
+          iter.next()
+        }
+        iter
+      }
+    }.collect()
+  }
+
   // NOTE
   // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
   // running after them and if they access sc those tests will fail as sc is already closed, because
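
The regression test leans on `ThreadUtils.runInNewThread`, which runs its body synchronously on a freshly created thread, returns the body's result, and rethrows any exception in the caller, so a lock-accounting failure on `TestThread` fails the test itself. A small usage sketch:

```scala
import org.apache.spark.util.ThreadUtils

// Blocks until the body finishes on the new thread; exceptions (including
// assertion failures) propagate back to the calling thread.
val answer: Int = ThreadUtils.runInNewThread("demo-thread") {
  1 + 41
}
assert(answer == 42)
```

Without the fix, the completion callback in the test would run on `TestThread`, where `TaskContext.get()` returns null, so the release would be attributed to the wrong id and race with the end-of-task lock cleanup.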