From a0bb569a818f6ce66c192a3f5782ff56cf58b1d3 Mon Sep 17 00:00:00 2001
From: Aaron Davidson <aaron@databricks.com>
Date: Sun, 3 Nov 2013 20:45:11 -0800
Subject: [PATCH] use OpenHashMap, remove monotonicity requirement, fix failure bug

---
 .../spark/scheduler/ShuffleMapTask.scala      |  4 +-
 .../spark/storage/ShuffleBlockManager.scala   | 56 ++++++-------------
 .../spark/storage/StoragePerfTester.scala     |  2 +-
 .../collection/PrimitiveKeyOpenHashMap.scala  |  5 ++
 4 files changed, 26 insertions(+), 41 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 24d97da6eb..c502f8f91a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -148,6 +148,7 @@ private[spark] class ShuffleMapTask(
     val blockManager = SparkEnv.get.blockManager
     var shuffle: ShuffleBlocks = null
     var buckets: ShuffleWriterGroup = null
+    var success = false

     try {
       // Obtain all the block writers for shuffle blocks.
@@ -179,6 +180,7 @@ private[spark] class ShuffleMapTask(
       shuffleMetrics.shuffleWriteTime = totalTime
       metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)

+      success = true
       new MapStatus(blockManager.blockManagerId, compressedSizes)
     } catch { case e: Exception =>
       // If there is an exception from running the task, revert the partial writes
@@ -191,7 +193,7 @@ private[spark] class ShuffleMapTask(
       // Release the writers back to the shuffle block manager.
       if (shuffle != null && buckets != null) {
         buckets.writers.foreach(_.close())
-        shuffle.releaseWriters(buckets)
+        shuffle.releaseWriters(buckets, success)
       }
       // Execute the callbacks on task completion.
       context.executeOnCompleteCallbacks()
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 57b1a28543..8b202ac112 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -28,7 +28,7 @@ import scala.collection.mutable
 import org.apache.spark.Logging
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
-import org.apache.spark.util.collection.PrimitiveVector
+import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}

 private[spark]
 class ShuffleWriterGroup(
@@ -41,7 +41,8 @@ trait ShuffleBlocks {
   /** Get a group of writers for this map task. */
   def acquireWriters(mapId: Int): ShuffleWriterGroup

-  def releaseWriters(group: ShuffleWriterGroup)
+  /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
+  def releaseWriters(group: ShuffleWriterGroup, success: Boolean)
 }

 /**
@@ -123,12 +124,14 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
         new ShuffleWriterGroup(mapId, fileGroup, writers)
       }

-      override def releaseWriters(group: ShuffleWriterGroup) {
+      override def releaseWriters(group: ShuffleWriterGroup, success: Boolean) {
         if (consolidateShuffleFiles) {
           val fileGroup = group.fileGroup
-          fileGroup.addMapper(group.mapId)
-          for ((writer, shuffleFile) <- group.writers.zip(fileGroup.files)) {
-            shuffleFile.recordMapOutput(writer.fileSegment().offset)
+          if (success) {
+            fileGroup.addMapper(group.mapId)
+            for ((writer, shuffleFile) <- group.writers.zip(fileGroup.files)) {
+              shuffleFile.recordMapOutput(writer.fileSegment().offset)
+            }
           }
           recycleFileGroup(shuffleId, fileGroup)
         }
@@ -149,18 +152,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {

   private def getUnusedFileGroup(shuffleId: Int, mapId: Int, numBuckets: Int): ShuffleFileGroup = {
     val pool = shuffleToFileGroupPoolMap(shuffleId)
-    var fileGroup = pool.getUnusedFileGroup()
-
-    // If we reuse a file group, ensure we maintain mapId monotonicity.
-    // This means we may create extra ShuffleFileGroups if we're trying to run a map task
-    // that is out-of-order with respect to its mapId (which may happen when failures occur).
-    val fileGroupsToReturn = mutable.ListBuffer[ShuffleFileGroup]()
-    while (fileGroup != null && fileGroup.maxMapId >= mapId) {
-      fileGroupsToReturn += fileGroup
-      fileGroup = pool.getUnusedFileGroup()
-    }
-    pool.returnFileGroups(fileGroupsToReturn) // re-add incompatible file groups
-
+    val fileGroup = pool.getUnusedFileGroup()
     if (fileGroup == null) {
       val fileId = pool.getNextFileId()
       val files = Array.tabulate[ShuffleFile](numBuckets) { bucketId =>
@@ -187,7 +179,6 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
    */
   def getBlockLocation(id: ShuffleBlockId): FileSegment = {
     // Search all files associated with the given reducer.
-    // This process is O(m log n) for m threads and n mappers. Could be sweetened to "likely" O(m).
     val filesForReducer = shuffleToReducerToFilesMap(id.shuffleId)(id.reduceId)
     for (file <- filesForReducer) {
       val segment = file.getFileSegmentFor(id.mapId)
@@ -210,37 +201,24 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
 /**
  * A group of shuffle files, one per reducer.
  * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
- * Mappers must be added in monotonically increasing order by id for efficiency purposes.
  */
 private[spark]
 class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[ShuffleFile]) {
   /**
-   * Contains the set of mappers that have written to this file group, in the same order as they
-   * have written to their respective files.
+   * Stores the absolute index of each mapId in the files of this group. For instance,
+   * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
    */
-  private val mapIds = new PrimitiveVector[Int]()
+  private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()

   files.foreach(_.setShuffleFileGroup(this))

-  /** The maximum map id (i.e., last added map task) in this file group. */
-  def maxMapId = if (mapIds.length > 0) mapIds(mapIds.length - 1) else -1
-
   def apply(bucketId: Int) = files(bucketId)

   def addMapper(mapId: Int) {
-    assert(mapId > maxMapId, "Attempted to insert mapId out-of-order")
-    mapIds += mapId
+    mapIdToIndex(mapId) = mapIdToIndex.size
   }

-  /**
-   * Uses binary search, giving O(log n) runtime.
-   * NB: Could be improved to amortized O(1) for usual access pattern, where nodes are accessed
-   * in order of monotonically increasing mapId. That approach is more fragile in general, however.
-   */
-  def indexOf(mapId: Int): Int = {
-    val index = util.Arrays.binarySearch(mapIds.getUnderlyingArray, 0, mapIds.length, mapId)
-    if (index >= 0) index else -1
-  }
+  def indexOf(mapId: Int): Int = mapIdToIndex.getOrElse(mapId, -1)
 }

 /**
@@ -252,7 +230,7 @@ class ShuffleFile(val file: File) {
   /**
    * Consecutive offsets of blocks into the file, ordered by position in the file.
    * This ordering allows us to compute block lengths by examining the following block offset.
-   * blockOffsets(i) contains the offset for the mapper in shuffleFileGroup.mapIds(i).
+   * Note: shuffleFileGroup.indexOf(mapId) returns the index of the mapper into this array.
    */
   private val blockOffsets = new PrimitiveVector[Long]()

@@ -284,7 +262,7 @@ class ShuffleFile(val file: File) {
           file.length() - offset
         }
       assert(length >= 0)
-      return Some(new FileSegment(file, offset, length))
+      Some(new FileSegment(file, offset, length))
     } else {
       None
     }
diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
index 7dcadc3805..021f6f6688 100644
--- a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
@@ -50,7 +50,7 @@ object StoragePerfTester {
         w.close()
       }

-      shuffle.releaseWriters(buckets)
+      shuffle.releaseWriters(buckets, true)
     }

     val start = System.currentTimeMillis()
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
index 4adf9cfb76..a119880884 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -53,6 +53,11 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
     _values(pos)
   }

+  def getOrElse(k: K, elseValue: V): V = {
+    val pos = _keySet.getPos(k)
+    if (pos >= 0) _values(pos) else elseValue
+  }
+
   /** Set the value for a key */
   def update(k: K, v: V) {
     val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
--
GitLab
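Note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the bookkeeping this change introduces: block indexes are looked up through a mapId-to-index hash map (so mapIds no longer have to arrive in monotonically increasing order), and a block's length is derived from the next block's offset. Standard-library collections stand in for Spark's PrimitiveKeyOpenHashMap and PrimitiveVector, and the names FileGroupSketch, recordMapOutput, and segmentFor are hypothetical, not Spark APIs.

// Illustrative sketch only -- not Spark code.
import scala.collection.mutable

class FileGroupSketch {
  // mapId -> position of that mapper's block within the file (insertion order, not mapId order),
  // mirroring ShuffleFileGroup.mapIdToIndex in the patch above.
  private val mapIdToIndex = mutable.HashMap.empty[Int, Int]
  // Consecutive block offsets, ordered by position in the file, mirroring ShuffleFile.blockOffsets.
  private val blockOffsets = mutable.ArrayBuffer.empty[Long]
  private var fileLength = 0L

  /** Record that `mapId` wrote a block starting at `offset` and spanning `length` bytes. */
  def recordMapOutput(mapId: Int, offset: Long, length: Long): Unit = {
    mapIdToIndex(mapId) = blockOffsets.size // no monotonicity assertion needed
    blockOffsets += offset
    fileLength = offset + length
  }

  /** (offset, length) of a mapper's block; the length comes from the next block's offset. */
  def segmentFor(mapId: Int): Option[(Long, Long)] =
    mapIdToIndex.get(mapId).map { index =>
      val offset = blockOffsets(index)
      val next = if (index + 1 < blockOffsets.size) blockOffsets(index + 1) else fileLength
      (offset, next - offset)
    }
}

object FileGroupSketchDemo extends App {
  val group = new FileGroupSketch
  group.recordMapOutput(mapId = 7, offset = 0L, length = 100L) // out-of-order mapIds are fine
  group.recordMapOutput(mapId = 3, offset = 100L, length = 50L)
  assert(group.segmentFor(7).contains((0L, 100L)))
  assert(group.segmentFor(3).contains((100L, 50L)))
  assert(group.segmentFor(99).isEmpty)
}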