diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 76d537f8e838a75ec9f6296e143e13011a46408d..fbedfbc4460217b6ffa103b172f88979be656d4a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.storage
 
-import java.io.{InputStream, OutputStream}
+import java.io.{File, InputStream, OutputStream}
 import java.nio.{ByteBuffer, MappedByteBuffer}
 
 import scala.collection.mutable.{HashMap, ArrayBuffer}
@@ -47,7 +47,7 @@ private[spark] class BlockManager(
   extends Logging {
 
   val shuffleBlockManager = new ShuffleBlockManager(this)
-  val diskBlockManager = new DiskBlockManager(
+  val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
     System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
 
   private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -462,15 +462,11 @@ private[spark] class BlockManager(
    * This is currently used for writing shuffle files out. Callers should handle error
    * cases.
    */
-  def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
+  def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int)
     : BlockObjectWriter = {
     val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
-    val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true)
     val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
     writer.registerCloseEventHandler(() => {
-      if (shuffleBlockManager.consolidateShuffleFiles) {
-        diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment())
-      }
       val myInfo = new ShuffleBlockInfo()
       blockInfo.put(blockId, myInfo)
       myInfo.markReady(writer.fileSegment().length)
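With this change, getDiskWriter no longer resolves file names itself: callers look the File up through DiskBlockManager and hand it in. A minimal caller-side sketch (not from the patch; `blockManager` and `serializer` are assumed to be in scope):

    val blockId = ShuffleBlockId(shuffleId = 0, mapId = 5, reduceId = 2)
    val file = blockManager.diskBlockManager.getFile(blockId)   // resolve the physical File first
    val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
    val writer = blockManager.getDiskWriter(blockId, file, serializer, bufferSize)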
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 32d2dd06943a0952f7a6763397cbb81000b17933..e49c191c70a1176e9ac0a25994343f4556bcb8ac 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -78,11 +78,11 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
 
 /** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
 class DiskBlockObjectWriter(
-    blockId: BlockId,
-    file: File,
-    serializer: Serializer,
-    bufferSize: Int,
-    compressStream: OutputStream => OutputStream)
+                             blockId: BlockId,
+                             file: File,
+                             serializer: Serializer,
+                             bufferSize: Int,
+                             compressStream: OutputStream => OutputStream)
   extends BlockObjectWriter(blockId)
   with Logging
 {
@@ -111,8 +111,8 @@ class DiskBlockObjectWriter(
   private var fos: FileOutputStream = null
   private var ts: TimeTrackingOutputStream = null
   private var objOut: SerializationStream = null
-  private var initialPosition = 0L
-  private var lastValidPosition = 0L
+  private val initialPosition = file.length()
+  private var lastValidPosition = initialPosition
   private var initialized = false
   private var _timeWriting = 0L
 
@@ -120,7 +120,6 @@ class DiskBlockObjectWriter(
     fos = new FileOutputStream(file, true)
     ts = new TimeTrackingOutputStream(fos)
     channel = fos.getChannel()
-    initialPosition = channel.position
     lastValidPosition = initialPosition
     bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
     objOut = serializer.newInstance().serializeStream(bs)
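Setting initialPosition to file.length() at construction time (rather than reading channel.position after open) makes the writer's FileSegment offset well-defined for files that successive writers append to. A self-contained illustration of that bookkeeping, with made-up sizes:

    import java.io.{File, FileOutputStream}

    val file = File.createTempFile("append-demo", ".dat")
    def append(bytes: Array[Byte]) {
      val out = new FileOutputStream(file, true)   // append mode, as in DiskBlockObjectWriter
      try { out.write(bytes) } finally { out.close() }
    }
    append(Array.fill(15)(0.toByte))
    val offset = file.length()                     // 15: where the next writer's segment starts
    append(Array.fill(12)(1.toByte))
    val length = file.length() - offset            // 12: that writer's segment length
    println("FileSegment(offset=" + offset + ", length=" + length + ")")
    file.delete()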
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index bcb58ad9467e6c8ff6fcf611ec570edaebb5c735..4f9537d1c70fa9ff24769a8443ee312819e94503 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -20,12 +20,11 @@ package org.apache.spark.storage
 import java.io.File
 import java.text.SimpleDateFormat
 import java.util.{Date, Random}
-import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.Logging
 import org.apache.spark.executor.ExecutorExitCode
 import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.util.Utils
 
 /**
  * Creates and maintains the logical mapping between logical blocks and physical on-disk
@@ -35,7 +34,7 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashM
  *
  * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
  */
-private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging {
+private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String) extends PathResolver with Logging {
 
   private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
   private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
@@ -47,54 +46,25 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
   private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
   private var shuffleSender : ShuffleSender = null
 
-  // Stores only Blocks which have been specifically mapped to segments of files
-  // (rather than the default, which maps a Block to a whole file).
-  // This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks.
-  private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment]
-
-  val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup)
-
   addShutdownHook()
 
-  /**
-   * Creates a logical mapping from the given BlockId to a segment of a file.
-   * This will cause any accesses of the logical BlockId to be directed to the specified
-   * physical location.
-   */
-  def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) {
-    blockToFileSegmentMap.put(blockId, fileSegment)
-  }
-
   /**
    * Returns the phyiscal file segment in which the given BlockId is located.
    * If the BlockId has been mapped to a specific FileSegment, that will be returned.
    * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
    */
   def getBlockLocation(blockId: BlockId): FileSegment = {
-    if (blockToFileSegmentMap.internalMap.containsKey(blockId)) {
-      blockToFileSegmentMap.get(blockId).get
-    } else {
-      val file = getFile(blockId.name)
-      new FileSegment(file, 0, file.length())
+    if (blockId.isShuffle) {
+      val segment = shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
+      if (segment.isDefined) { return segment.get }
+      // If no special mapping found, assume standard block -> file mapping...
     }
-  }
-
-  /**
-   * Simply returns a File to place the given Block into. This does not physically create the file.
-   * If filename is given, that file will be used. Otherwise, we will use the BlockId to get
-   * a unique filename.
-   */
-  def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = {
-    val actualFilename = if (filename == "") blockId.name else filename
-    val file = getFile(actualFilename)
-    if (!allowAppending && file.exists()) {
-      throw new IllegalStateException(
-        "Attempted to create file that already exists: " + actualFilename)
-    }
-    file
+
+    val file = getFile(blockId.name)
+    new FileSegment(file, 0, file.length())
   }
 
-  private def getFile(filename: String): File = {
+  def getFile(filename: String): File = {
     // Figure out which local directory it hashes to, and which subdirectory in that
     val hash = Utils.nonNegativeHash(filename)
     val dirId = hash % localDirs.length
@@ -119,6 +89,8 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
     new File(subDir, filename)
   }
 
+  def getFile(blockId: BlockId): File = getFile(blockId.name)
+
   private def createLocalDirs(): Array[File] = {
     logDebug("Creating local directories at root dirs '" + rootDirs + "'")
     val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
@@ -151,10 +123,6 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
     }
   }
 
-  private def cleanup(cleanupTime: Long) {
-    blockToFileSegmentMap.clearOldValues(cleanupTime)
-  }
-
   private def addShutdownHook() {
     localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
     Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
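For reference, getFile hashes a filename to one of the local directories and then to one of the 64 sub-directories within it. A standalone sketch of that arithmetic (the directory names and the hash function below are stand-ins, not the patch's Utils.nonNegativeHash):

    val localDirs = Array("/tmp/spark-local-0", "/tmp/spark-local-1")
    val subDirsPerLocalDir = 64

    def locate(filename: String): (Int, Int) = {
      val hash = filename.hashCode & Integer.MAX_VALUE               // stand-in for Utils.nonNegativeHash
      val dirId = hash % localDirs.length                            // which root local dir
      val subDirId = (hash / localDirs.length) % subDirsPerLocalDir  // which sub-directory within it
      (dirId, subDirId)
    }

    println(locate("merged_shuffle_0_0_0"))   // deterministic (dirId, subDirId) for a given name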
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index a3c496f9e05c517f198510095471ab6623b40d22..5a1e7b44440fdac533ae6256ba61c33d70552b7d 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -44,7 +44,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
     val bytes = _bytes.duplicate()
     logDebug("Attempting to put block " + blockId)
     val startTime = System.currentTimeMillis
-    val file = diskManager.createBlockFile(blockId, allowAppending = false)
+    val file = diskManager.getFile(blockId)
     val channel = new FileOutputStream(file).getChannel()
     while (bytes.remaining > 0) {
       channel.write(bytes)
@@ -64,7 +64,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
 
     logDebug("Attempting to write values for block " + blockId)
     val startTime = System.currentTimeMillis
-    val file = diskManager.createBlockFile(blockId, allowAppending = false)
+    val file = diskManager.getFile(blockId)
     val outputStream = new FileOutputStream(file)
     blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
     val length = file.length
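DiskStore now simply asks DiskBlockManager for the block's File and writes into it; the old allowAppending = false guard against pre-existing files is gone. A small, self-contained sketch of the putBytes write path shown above (file name and payload are invented for the example):

    import java.io.{File, FileOutputStream}
    import java.nio.ByteBuffer

    val file = File.createTempFile("diskstore-demo", ".bin")
    val bytes = ByteBuffer.wrap("serialized block payload".getBytes("UTF-8"))
    val channel = new FileOutputStream(file).getChannel()   // plain (non-append) open
    while (bytes.remaining > 0) {
      channel.write(bytes)
    }
    channel.close()
    println(file.getName + " now holds " + file.length() + " bytes")
    file.delete()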
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 066e45a12b8c7a8e9784a42eba63c373e1b44378..c61febf830e94922545a2c6fee71e558f3c1fd8d 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -17,17 +17,29 @@
 
 package org.apache.spark.storage
 
+import java.io.File
+import java.util
 import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.concurrent.atomic.AtomicInteger
 
+import scala.collection.JavaConversions._
+import scala.collection.mutable
+
+import org.apache.spark.Logging
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, AGodDamnPrimitiveVector, TimeStampedHashMap}
 
 private[spark]
-class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter])
+class ShuffleWriterGroup(
+    val mapId: Int,
+    val fileGroup: ShuffleFileGroup,
+    val writers: Array[BlockObjectWriter])
 
 private[spark]
 trait ShuffleBlocks {
+  /** Get a group of writers for this map task. */
   def acquireWriters(mapId: Int): ShuffleWriterGroup
+
   def releaseWriters(group: ShuffleWriterGroup)
 }
 
@@ -46,51 +58,219 @@ trait ShuffleBlocks {
  * time owns a particular fileId, and this id is returned to a pool when the task finishes.
  */
 private[spark]
-class ShuffleBlockManager(blockManager: BlockManager) {
+class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
 
   // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
   // TODO: Remove this once the shuffle file consolidation feature is stable.
   val consolidateShuffleFiles =
     System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
 
-  var nextFileId = new AtomicInteger(0)
-  val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()
+  /**
+   * Contains a pool of unused ShuffleFileGroups.
+   * One group is needed per concurrent thread (mapper) operating on the same shuffle.
+   */
+  private class ShuffleFileGroupPool {
+    private val nextFileId = new AtomicInteger(0)
+    private val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+
+    def getNextFileId() = nextFileId.getAndIncrement()
+    def getUnusedFileGroup() = unusedFileGroups.poll()
+    def returnFileGroup(group: ShuffleFileGroup) = unusedFileGroups.add(group)
+    def returnFileGroups(groups: Seq[ShuffleFileGroup]) = unusedFileGroups.addAll(groups)
+  }
+
+  type ShuffleId = Int
+  private val shuffleToFileGroupPoolMap = new TimeStampedHashMap[ShuffleId, ShuffleFileGroupPool]
+
+  /**
+   * Maps reducers (of a particular shuffle) to the set of files that have blocks destined for them.
+   * Each reducer will have one ShuffleFile per concurrent thread that executed during mapping.
+   */
+  private val shuffleToReducerToFilesMap =
+    new TimeStampedHashMap[ShuffleId, Array[ConcurrentLinkedQueue[ShuffleFile]]]
+
+  private
+  val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup)
 
   def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = {
+    initializeShuffleMetadata(shuffleId, numBuckets)
+
     new ShuffleBlocks {
-      // Get a group of writers for a map task.
       override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
         val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
-        val fileId = getUnusedFileId()
+        val fileGroup = getUnusedFileGroup(shuffleId, mapId, numBuckets)
         val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
           if (consolidateShuffleFiles) {
-            val filename = physicalFileName(shuffleId, bucketId, fileId)
-            blockManager.getDiskWriter(blockId, filename, serializer, bufferSize)
+            blockManager.getDiskWriter(blockId, fileGroup(bucketId).file, serializer, bufferSize)
           } else {
-            blockManager.getDiskWriter(blockId, blockId.name, serializer, bufferSize)
+            val blockFile = blockManager.diskBlockManager.getFile(blockId)
+            blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
           }
         }
-        new ShuffleWriterGroup(mapId, fileId, writers)
+        new ShuffleWriterGroup(mapId, fileGroup, writers)
       }
 
       override def releaseWriters(group: ShuffleWriterGroup) {
-        recycleFileId(group.fileId)
+        if (consolidateShuffleFiles) {
+          val fileGroup = group.fileGroup
+          fileGroup.addMapper(group.mapId)
+          for ((writer, shuffleFile) <- group.writers.zip(fileGroup.files)) {
+            shuffleFile.recordMapOutput(writer.fileSegment().offset)
+          }
+          recycleFileGroup(shuffleId, fileGroup)
+        }
+      }
     }
   }
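A hedged sketch of how a map task is expected to drive this interface (names such as records, partitioner, shuffleId, numBuckets and mapId are placeholders assumed to be in scope; the write/commit/close methods come from BlockObjectWriter as used elsewhere in this version of the code):

    val shuffleBlocks = blockManager.shuffleBlockManager.forShuffle(shuffleId, numBuckets, serializer)
    val group = shuffleBlocks.acquireWriters(mapId)
    for ((key, value) <- records) {
      group.writers(partitioner.getPartition(key)).write((key, value))
    }
    group.writers.foreach { writer => writer.commit(); writer.close() }
    shuffleBlocks.releaseWriters(group)   // records per-reducer offsets and returns the file group to the pool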
 
+  def initializeShuffleMetadata(shuffleId: Int, numBuckets: Int) {
+    val prev = shuffleToFileGroupPoolMap.putIfAbsent(shuffleId, new ShuffleFileGroupPool())
+    if (prev == None) {
+      val reducerToFilesMap = new Array[ConcurrentLinkedQueue[ShuffleFile]](numBuckets)
+      for (reducerId <- 0 until numBuckets) {
+        reducerToFilesMap(reducerId) = new ConcurrentLinkedQueue[ShuffleFile]()
+      }
+      shuffleToReducerToFilesMap.put(shuffleId, reducerToFilesMap)
     }
   }
 
-  private def getUnusedFileId(): Int = {
-    val fileId = unusedFileIds.poll()
-    if (fileId == null) nextFileId.getAndIncrement() else fileId
+  private def getUnusedFileGroup(shuffleId: Int, mapId: Int, numBuckets: Int): ShuffleFileGroup = {
+    if (!consolidateShuffleFiles) { return null }
+
+    val pool = shuffleToFileGroupPoolMap(shuffleId)
+    var fileGroup = pool.getUnusedFileGroup()
+
+    // If we reuse a file group, ensure we maintain mapId monotonicity.
+    val fileGroupsToReturn = mutable.ListBuffer[ShuffleFileGroup]()
+    while (fileGroup != null && fileGroup.maxMapId >= mapId) {
+      fileGroupsToReturn += fileGroup
+      fileGroup = pool.getUnusedFileGroup()
+    }
+    pool.returnFileGroups(fileGroupsToReturn) // re-add incompatible file groups
+
+    if (fileGroup == null) {
+      val fileId = pool.getNextFileId()
+      val files = Array.tabulate[ShuffleFile](numBuckets) { bucketId =>
+        val filename = physicalFileName(shuffleId, bucketId, fileId)
+        val file = blockManager.diskBlockManager.getFile(filename)
+        val shuffleFile = new ShuffleFile(file)
+        shuffleToReducerToFilesMap(shuffleId)(bucketId).add(shuffleFile)
+        shuffleFile
+      }
+      new ShuffleFileGroup(shuffleId, fileId, files)
+    } else {
+      fileGroup
+    }
   }
 
-  private def recycleFileId(fileId: Int) {
+  private def recycleFileGroup(shuffleId: Int, fileGroup: ShuffleFileGroup) {
+    shuffleToFileGroupPoolMap(shuffleId).returnFileGroup(fileGroup)
+  }
+
+  /**
+   * Returns the physical file segment in which the given BlockId is located.
+   * If we have no special mapping, None will be returned.
+   */
+  def getBlockLocation(id: ShuffleBlockId): Option[FileSegment] = {
+    // Search all files associated with the given reducer.
+    // This process is O(m log n) for m threads and n mappers. Could be sweetened to "likely" O(m).
     if (consolidateShuffleFiles) {
-      unusedFileIds.add(fileId)
+      val filesForReducer = shuffleToReducerToFilesMap(id.shuffleId)(id.reduceId)
+      for (file <- filesForReducer) {
+        val segment = file.getFileSegmentFor(id.mapId)
+        if (segment != None) { return segment }
+      }
+
+      logInfo("Failed to find shuffle block: " + id)
     }
+    None
   }
 
   private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
     "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
   }
+
+  private def cleanup(cleanupTime: Long) {
+    shuffleToFileGroupPoolMap.clearOldValues(cleanupTime)
+    shuffleToReducerToFilesMap.clearOldValues(cleanupTime)
+  }
+}
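The consolidated files created by getUnusedFileGroup follow the physicalFileName scheme above: one file per reducer for each fileId. For example (made-up ids), the group with fileId = 3 in a 4-reducer shuffle 0 is backed by these files:

    val shuffleId = 0
    val fileId = 3
    val numBuckets = 4
    val names = (0 until numBuckets).map { bucketId =>
      "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
    }
    println(names.mkString(", "))
    // merged_shuffle_0_0_3, merged_shuffle_0_1_3, merged_shuffle_0_2_3, merged_shuffle_0_3_3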
+
+/**
+ * A group of shuffle files, one per reducer.
+ * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
+ * Mappers must be added in monotonically increasing order by id for efficiency purposes.
+ */
+private[spark]
+class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[ShuffleFile]) {
+  private val mapIds = new AGodDamnPrimitiveVector[Int]()
+
+  files.foreach(_.setShuffleFileGroup(this))
+
+  /** The maximum map id (i.e., last added map task) in this file group. */
+  def maxMapId = if (mapIds.length > 0) mapIds(mapIds.length - 1) else -1
+
+  def apply(bucketId: Int) = files(bucketId)
+
+  def addMapper(mapId: Int) {
+    assert(mapId > maxMapId, "Attempted to insert mapId out-of-order")
+    mapIds += mapId
+  }
+
+  /**
+   * Uses binary search, giving O(log n) runtime.
+   * NB: Could be improved to amortized O(1) for usual access pattern, where nodes are accessed
+   * in order of monotonically increasing mapId. That approach is more fragile in general, however.
+   */
+  def indexOf(mapId: Int): Int = {
+    val index = util.Arrays.binarySearch(mapIds.getUnderlyingArray, 0, mapIds.length, mapId)
+    if (index >= 0) index else -1
+  }
+}
+
+/**
+ * A single, consolidated shuffle file that may contain many actual blocks. All blocks are destined
+ * to the same reducer.
+ */
+private[spark]
+class ShuffleFile(val file: File) {
+  /**
+   * Consecutive offsets of blocks into the file, ordered by position in the file.
+   * This ordering allows us to compute block lengths by examining the following block offset.
+   */
+  val blockOffsets = new AGodDamnPrimitiveVector[Long]()
+
+  /** Back pointer to whichever ShuffleFileGroup this file is a part of. */
+  private var shuffleFileGroup : ShuffleFileGroup = _
+
+  // Required due to circular dependency between ShuffleFileGroup and ShuffleFile.
+  def setShuffleFileGroup(group: ShuffleFileGroup) {
+    assert(shuffleFileGroup == null)
+    shuffleFileGroup = group
+  }
+
+  def recordMapOutput(offset: Long) {
+    blockOffsets += offset
+  }
+
+  /**
+   * Returns the FileSegment associated with the given map task, or
+   * None if this ShuffleFile does not have an entry for it.
+   */
+  def getFileSegmentFor(mapId: Int): Option[FileSegment] = {
+    val index = shuffleFileGroup.indexOf(mapId)
+    if (index >= 0) {
+      val offset = blockOffsets(index)
+      val length =
+        if (index + 1 < blockOffsets.length) {
+          blockOffsets(index + 1) - offset
+        } else {
+          file.length() - offset
+        }
+      assert(length >= 0)
+      return Some(new FileSegment(file, offset, length))
+    } else {
+      None
+    }
+  }
 }
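How ShuffleFile turns recorded offsets into FileSegments, worked through with example numbers (three map outputs at offsets 0, 100 and 250 in a 400-byte file):

    val blockOffsets = Array(0L, 100L, 250L)   // offsets recorded by recordMapOutput, in file order
    val fileLength = 400L

    val segments = blockOffsets.zipWithIndex.map { case (offset, i) =>
      val length =
        if (i + 1 < blockOffsets.length) blockOffsets(i + 1) - offset
        else fileLength - offset
      (offset, length)
    }
    println(segments.mkString(", "))   // (0,100), (100,150), (250,150)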
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 3f963727d98ddd3aca90c8bb327e410dceb6f546..67a7f87a5ca6e40bdb254ebee8c61b6e459c856e 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -59,7 +59,7 @@ object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext
   "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
 
   val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
-    SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value
+    SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
 
   type MetadataCleanerType = Value
 
diff --git a/core/src/main/scala/org/apache/spark/util/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/PrimitiveVector.scala
new file mode 100644
index 0000000000000000000000000000000000000000..d316601b905b9cff52ca7af437d19c0b01ada80a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/PrimitiveVector.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/** Provides a simple, non-threadsafe, array-backed vector that can store primitives. */
+class AGodDamnPrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest]
+    (initialSize: Int = 64)
+{
+  private var numElements = 0
+  private var array = new Array[V](initialSize)
+
+  def apply(index: Int): V = {
+    require(index < numElements)
+    array(index)
+  }
+
+  def +=(value: V) {
+    if (numElements == array.length) { resize(array.length * 2) }
+    array(numElements) = value
+    numElements += 1
+  }
+
+  def length = numElements
+
+  def getUnderlyingArray = array
+
+  /** Resizes the array, dropping elements if the total length decreases. */
+  def resize(newLength: Int) {
+    val newArray = new Array[V](newLength)
+    array.copyToArray(newArray)
+    array = newArray
+  }
+}
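A usage sketch for the vector above. Appends double the backing array when it fills up, and getUnderlyingArray exposes the raw (possibly oversized) storage, which is why ShuffleFileGroup.indexOf bounds its binary search by length rather than by the array's size:

    import org.apache.spark.util.AGodDamnPrimitiveVector

    val mapIds = new AGodDamnPrimitiveVector[Int](initialSize = 2)
    Seq(3, 7, 12, 20, 31).foreach(mapIds += _)

    println(mapIds.length)                     // 5 elements stored
    println(mapIds.getUnderlyingArray.length)  // 8: capacity after doubling 2 -> 4 -> 8
    println(java.util.Arrays.binarySearch(mapIds.getUnderlyingArray, 0, mapIds.length, 12))  // 2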
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..12ca920094d9a5cadc41a0fa497041a2fb2a6896
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -0,0 +1,80 @@
+package org.apache.spark.storage
+
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+import java.io.{FileWriter, File}
+import java.nio.file.Files
+import scala.collection.mutable
+
+class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach {
+
+  val rootDir0 = Files.createTempDirectory("disk-block-manager-suite-0")
+  val rootDir1 = Files.createTempDirectory("disk-block-manager-suite-1")
+  val rootDirs = rootDir0.getFileName + "," + rootDir1.getFileName
+  println("Created root dirs: " + rootDirs)
+
+  val shuffleBlockManager = new ShuffleBlockManager(null) {
+    var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]()
+    override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap.get(id)
+  }
+
+  var diskBlockManager: DiskBlockManager = _
+
+  override def beforeEach() {
+    diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs)
+    shuffleBlockManager.idToSegmentMap.clear()
+  }
+
+  test("basic block creation") {
+    val blockId = new TestBlockId("test")
+    assertSegmentEquals(blockId, blockId.name, 0, 0)
+
+    val newFile = diskBlockManager.getFile(blockId)
+    writeToFile(newFile, 10)
+    assertSegmentEquals(blockId, blockId.name, 0, 10)
+
+    newFile.delete()
+  }
+
+  test("block appending") {
+    val blockId = new TestBlockId("test")
+    val newFile = diskBlockManager.getFile(blockId)
+    writeToFile(newFile, 15)
+    assertSegmentEquals(blockId, blockId.name, 0, 15)
+    val newFile2 = diskBlockManager.getFile(blockId)
+    assert(newFile === newFile2)
+    writeToFile(newFile2, 12)
+    assertSegmentEquals(blockId, blockId.name, 0, 27)
+    newFile.delete()
+  }
+
+  test("block remapping") {
+    val filename = "test"
+    val blockId0 = new ShuffleBlockId(1, 2, 3)
+    val newFile = diskBlockManager.getFile(filename)
+    writeToFile(newFile, 15)
+    shuffleBlockManager.idToSegmentMap(blockId0) = new FileSegment(newFile, 0, 15)
+    assertSegmentEquals(blockId0, filename, 0, 15)
+
+    val blockId1 = new ShuffleBlockId(1, 2, 4)
+    val newFile2 = diskBlockManager.getFile(filename)
+    writeToFile(newFile2, 12)
+    shuffleBlockManager.idToSegmentMap(blockId1) = new FileSegment(newFile, 15, 12)
+    assertSegmentEquals(blockId1, filename, 15, 12)
+
+    assert(newFile === newFile2)
+    newFile.delete()
+  }
+
+  def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) {
+    val segment = diskBlockManager.getBlockLocation(blockId)
+    assert(segment.file.getName === filename)
+    assert(segment.offset === offset)
+    assert(segment.length === length)
+  }
+
+  def writeToFile(file: File, numBytes: Int) {
+    val writer = new FileWriter(file, true)
+    for (i <- 0 until numBytes) writer.write(i)
+    writer.close()
+  }
+}
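Why the length assertions in these tests line up with writeToFile's numBytes argument: FileWriter.write(i: Int) emits the single character with code point i, and for i < 128 each such character encodes to one byte in any ASCII-compatible default charset. A quick standalone check:

    import java.io.{File, FileWriter}

    val f = File.createTempFile("write-demo", ".txt")
    val writer = new FileWriter(f, true)
    for (i <- 0 until 10) writer.write(i)
    writer.close()
    println(f.length())   // 10 on ASCII-compatible default charsets
    f.delete()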