diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index 8f83668d7902989ec17e86c90b292cb79f7c11e3..b3f8bfe8b1d48627c02a53075a5c5b0ff62df7a7 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -46,5 +46,5 @@ trait BlockDataManager {
   /**
    * Release locks acquired by [[putBlockData()]] and [[getBlockData()]].
    */
-  def releaseLock(blockId: BlockId): Unit
+  def releaseLock(blockId: BlockId, taskAttemptId: Option[Long]): Unit
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
index 3db59837fbebd4e0c1daeb52e8f95f4ef9a6417f..7064872ec1c7751d45c679f293d6b5ceaacc4329 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
@@ -281,22 +281,27 @@ private[storage] class BlockInfoManager extends Logging {
 
   /**
    * Release a lock on the given block.
+   * In case a TaskContext is not propagated properly to all child threads for the task, we fail to
+   * get the TID from TaskContext, so we have to explicitly pass the TID value to release the lock.
+   *
+   * See SPARK-18406 for more discussion of this issue.
    */
-  def unlock(blockId: BlockId): Unit = synchronized {
-    logTrace(s"Task $currentTaskAttemptId releasing lock for $blockId")
+  def unlock(blockId: BlockId, taskAttemptId: Option[TaskAttemptId] = None): Unit = synchronized {
+    val taskId = taskAttemptId.getOrElse(currentTaskAttemptId)
+    logTrace(s"Task $taskId releasing lock for $blockId")
     val info = get(blockId).getOrElse {
       throw new IllegalStateException(s"Block $blockId not found")
     }
     if (info.writerTask != BlockInfo.NO_WRITER) {
       info.writerTask = BlockInfo.NO_WRITER
-      writeLocksByTask.removeBinding(currentTaskAttemptId, blockId)
+      writeLocksByTask.removeBinding(taskId, blockId)
     } else {
       assert(info.readerCount > 0, s"Block $blockId is not locked for reading")
       info.readerCount -= 1
-      val countsForTask = readLocksByTask(currentTaskAttemptId)
+      val countsForTask = readLocksByTask(taskId)
       val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1
       assert(newPinCountForTask >= 0,
-        s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it")
+        s"Task $taskId release lock on block $blockId more times than it acquired it")
     }
     notifyAll()
   }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 137d24b525155cb64413ee8812c5a258381aa884..1689baa832d52469d7dc8173c040e0da30185719 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -501,6 +501,7 @@ private[spark] class BlockManager(
       case Some(info) =>
         val level = info.level
         logDebug(s"Level for block $blockId is $level")
+        val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId())
         if (level.useMemory && memoryStore.contains(blockId)) {
           val iter: Iterator[Any] = if (level.deserialized) {
             memoryStore.getValues(blockId).get
@@ -508,7 +509,12 @@
             serializerManager.dataDeserializeStream(
               blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag)
           }
-          val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
+          // We need to capture the current taskId in case the iterator completion is triggered
+          // from a different thread which does not have TaskContext set; see SPARK-18406 for
+          // discussion.
+          val ci = CompletionIterator[Any, Iterator[Any]](iter, {
+            releaseLock(blockId, taskAttemptId)
+          })
           Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
         } else if (level.useDisk && diskStore.contains(blockId)) {
           val diskData = diskStore.getBytes(blockId)
@@ -525,8 +531,9 @@
               serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
             }
           }
-          val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn,
-            releaseLockAndDispose(blockId, diskData))
+          val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, {
+            releaseLockAndDispose(blockId, diskData, taskAttemptId)
+          })
           Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
         } else {
           handleLocalReadFailure(blockId)
@@ -711,10 +718,13 @@ private[spark] class BlockManager(
   }
 
   /**
-   * Release a lock on the given block.
+   * Release a lock on the given block with explicit TID.
+   * The param `taskAttemptId` should be passed in case we can't get the correct TID from
+   * TaskContext, for example, the input iterator of a cached RDD iterates to the end in a child
+   * thread.
    */
-  def releaseLock(blockId: BlockId): Unit = {
-    blockInfoManager.unlock(blockId)
+  def releaseLock(blockId: BlockId, taskAttemptId: Option[Long] = None): Unit = {
+    blockInfoManager.unlock(blockId, taskAttemptId)
   }
 
   /**
@@ -1467,8 +1477,11 @@ private[spark] class BlockManager(
     }
   }
 
-  def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = {
-    blockInfoManager.unlock(blockId)
+  def releaseLockAndDispose(
+      blockId: BlockId,
+      data: BlockData,
+      taskAttemptId: Option[Long] = None): Unit = {
+    releaseLock(blockId, taskAttemptId)
     data.dispose()
   }
 
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ad56715656c85d1e3e5fa39b477becec524549a1..8d06f5468f4f1e50e6ee5ab7a9e7a93109463743 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -30,7 +30,7 @@ import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.rdd.RDDSuiteUtils._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 class RDDSuite extends SparkFunSuite with SharedSparkContext {
   var tempDir: File = _
@@ -1082,6 +1082,22 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     assert(totalPartitionCount == 10)
   }
 
+  test("SPARK-18406: race between end-of-task and completion iterator read lock release") {
+    val rdd = sc.parallelize(1 to 1000, 10)
+    rdd.cache()
+
+    rdd.mapPartitions { iter =>
+      ThreadUtils.runInNewThread("TestThread") {
+        // Iterate to the end of the input iterator, to cause the CompletionIterator completion to
+        // fire outside of the task's main thread.
+        while (iter.hasNext) {
+          iter.next()
+        }
+        iter
+      }
+    }.collect()
+  }
+
   // NOTE
   // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
   // running after them and if they access sc those tests will fail as sc is already closed, because
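Note (illustration only, not part of the patch): the core idea of the change is to capture the task attempt id eagerly, on the thread that still has the TaskContext, because the CompletionIterator callback may fire on a child thread where TaskContext.get() returns null. The self-contained Scala sketch below demonstrates that failure mode with a plain ThreadLocal standing in for TaskContext; the object and value names are made up for the example and are not Spark API.

import java.lang.{Long => JLong}

// All names below are illustrative; this is a standalone sketch, not Spark code.
object ThreadLocalCaptureSketch {
  // Stand-in for TaskContext.get(): a per-thread "current task attempt id".
  private val currentTaskAttemptId = new ThreadLocal[JLong]

  def main(args: Array[String]): Unit = {
    // The "task thread" registers its id in the thread-local.
    currentTaskAttemptId.set(JLong.valueOf(42L))

    // Broken approach: resolve the id lazily, inside the callback.
    val lazyId: () => Option[JLong] = () => Option(currentTaskAttemptId.get())

    // Approach taken by the patch: capture the id eagerly on the task thread and carry it along.
    val capturedId: Option[Long] = Option(currentTaskAttemptId.get()).map(_.longValue())

    val child = new Thread(new Runnable {
      override def run(): Unit = {
        // The thread-local is not set in this thread, so the lazy lookup finds nothing...
        println(s"lazy lookup in child thread:   ${lazyId()}")   // prints None
        // ...while the value captured on the parent thread is still available.
        println(s"captured id in child thread:   $capturedId")   // prints Some(42)
      }
    })
    child.start()
    child.join()
  }
}

Because the new parameter defaults to None on BlockManager.releaseLock and BlockInfoManager.unlock, callers that still release locks on the task's main thread keep working unchanged and fall back to currentTaskAttemptId; only call sites whose completion may run on another thread, such as the completion callbacks in getLocalValues, need to pass the captured id explicitly.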