diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index fac416a5b3d5d24ca16ab7652c145c2e44302d15..bb78207c9f3c2d924a12e909c524b24870572c5a 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -67,11 +67,20 @@ object BlockFetcherIterator { throw new IllegalArgumentException("BlocksByAddress is null") } - protected var _totalBlocks = blocksByAddress.map(_._2.size).sum - logDebug("Getting " + _totalBlocks + " blocks") + // Total number blocks fetched (local + remote). Also number of FetchResults expected + protected var _numBlocksToFetch = 0 + protected var startTime = System.currentTimeMillis - protected val localBlockIds = new ArrayBuffer[String]() - protected val remoteBlockIds = new HashSet[String]() + + // This represents the number of local blocks, also counting zero-sized blocks + private var numLocal = 0 + // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks + protected val localBlocksToFetch = new ArrayBuffer[String]() + + // This represents the number of remote blocks, also counting zero-sized blocks + private var numRemote = 0 + // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks + protected val remoteBlocksToFetch = new HashSet[String]() // A queue to hold our results. protected val results = new LinkedBlockingQueue[FetchResult] @@ -124,13 +133,15 @@ object BlockFetcherIterator { protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. - val originalTotalBlocks = _totalBlocks val remoteRequests = new ArrayBuffer[FetchRequest] for ((address, blockInfos) <- blocksByAddress) { if (address == blockManagerId) { - localBlockIds ++= blockInfos.map(_._1) + numLocal = blockInfos.size + // Filter out zero-sized blocks + localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) + _numBlocksToFetch += localBlocksToFetch.size } else { - remoteBlockIds ++= blockInfos.map(_._1) + numRemote += blockInfos.size // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. @@ -144,10 +155,10 @@ object BlockFetcherIterator { // Skip empty blocks if (size > 0) { curBlocks += ((blockId, size)) + remoteBlocksToFetch += blockId + _numBlocksToFetch += 1 curRequestSize += size - } else if (size == 0) { - _totalBlocks -= 1 - } else { + } else if (size < 0) { throw new BlockException(blockId, "Negative block size " + size) } if (curRequestSize >= minRequestSize) { @@ -163,8 +174,8 @@ object BlockFetcherIterator { } } } - logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " + - originalTotalBlocks + " blocks") + logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " + + totalBlocks + " blocks") remoteRequests } @@ -172,7 +183,7 @@ object BlockFetcherIterator { // Get the local blocks while remote blocks are being fetched. Note that it's okay to do // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight - for (id <- localBlockIds) { + for (id <- localBlocksToFetch) { getLocalFromDisk(id, serializer) match { case Some(iter) => { // Pass 0 as size since it's not in flight @@ -198,7 +209,7 @@ object BlockFetcherIterator { sendRequest(fetchRequests.dequeue()) } - val numGets = remoteBlockIds.size - fetchRequests.size + val numGets = remoteRequests.size - fetchRequests.size logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) // Get Local Blocks @@ -210,7 +221,7 @@ object BlockFetcherIterator { //an iterator that will read fetched blocks off the queue as they arrive. @volatile protected var resultsGotten = 0 - override def hasNext: Boolean = resultsGotten < _totalBlocks + override def hasNext: Boolean = resultsGotten < _numBlocksToFetch override def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 @@ -227,9 +238,9 @@ object BlockFetcherIterator { } // Implementing BlockFetchTracker trait. - override def totalBlocks: Int = _totalBlocks - override def numLocalBlocks: Int = localBlockIds.size - override def numRemoteBlocks: Int = remoteBlockIds.size + override def totalBlocks: Int = numLocal + numRemote + override def numLocalBlocks: Int = numLocal + override def numRemoteBlocks: Int = numRemote override def remoteFetchTime: Long = _remoteFetchTime override def fetchWaitTime: Long = _fetchWaitTime override def remoteBytesRead: Long = _remoteBytesRead @@ -291,7 +302,7 @@ object BlockFetcherIterator { private var copiers: List[_ <: Thread] = null override def initialize() { - // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks + // Split Local Remote Blocks and set numBlocksToFetch val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order for (request <- Utils.randomize(remoteRequests)) { @@ -313,7 +324,7 @@ object BlockFetcherIterator { val result = results.take() // if all the results has been retrieved, shutdown the copiers // NO need to stop the copiers if we got all the blocks ? - // if (resultsGotten == _totalBlocks && copiers != null) { + // if (resultsGotten == _numBlocksToFetch && copiers != null) { // stopCopiers() // } (result.blockId, if (result.failed) None else Some(result.deserialize())) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index c7281200e7e0086660e9cfeb65d288ce28f7875b..0af6e4a35917fd8551f15ab6c6d47c391824d4b2 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -35,21 +35,25 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private var bs: OutputStream = null private var objOut: SerializationStream = null private var lastValidPosition = 0L + private var initialized = false override def open(): DiskBlockObjectWriter = { val fos = new FileOutputStream(f, true) channel = fos.getChannel() - bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos)) + bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize)) objOut = serializer.newInstance().serializeStream(bs) + initialized = true this } override def close() { - objOut.close() - bs.close() - channel = null - bs = null - objOut = null + if (initialized) { + objOut.close() + bs.close() + channel = null + bs = null + objOut = null + } // Invoke the close callback handler. super.close() } @@ -59,23 +63,33 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // Flush the partial writes, and set valid length to be the length of the entire file. // Return the number of bytes written for this commit. override def commit(): Long = { - // NOTE: Flush the serializer first and then the compressed/buffered output stream - objOut.flush() - bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos + if (initialized) { + // NOTE: Flush the serializer first and then the compressed/buffered output stream + objOut.flush() + bs.flush() + val prevPos = lastValidPosition + lastValidPosition = channel.position() + lastValidPosition - prevPos + } else { + // lastValidPosition is zero if stream is uninitialized + lastValidPosition + } } override def revertPartialWrites() { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. - objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) + if (initialized) { + // Discard current writes. We do this by flushing the outstanding writes and + // truncate the file to the last valid position. + objOut.flush() + bs.flush() + channel.truncate(lastValidPosition) + } } override def write(value: Any) { + if (!initialized) { + open() + } objOut.writeObject(value) } @@ -197,7 +211,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { val file = getFile(blockId) if (!allowAppendExisting && file.exists()) { - throw new Exception("File for block " + blockId + " already exists on disk: " + file) + // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task + // was rescheduled on the same machine as the old task ? + logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") + file.delete() + // throw new Exception("File for block " + blockId + " already exists on disk: " + file) } file } diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala index 49eabfb0d21505616be4f72531fe074118b3a53c..44638e0c2d8f9be93af4f0904037db9921c4a473 100644 --- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala @@ -24,7 +24,7 @@ class ShuffleBlockManager(blockManager: BlockManager) { val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) - blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open() + blockManager.getDiskBlockWriter(blockId, serializer, bufferSize) } new ShuffleWriterGroup(mapId, writers) } diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index b967016cf726791b543781a9f42cf8c9607aab71..33b02fff801445a678b5ab4dfde1014bc8d59f44 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -367,6 +367,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(nonEmptyBlocks.size <= 4) } + test("zero sized blocks without kryo") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + + // 10 partitions from 4 keys + val NUM_BLOCKS = 10 + val a = sc.parallelize(1 to 4, NUM_BLOCKS) + val b = a.map(x => (x, x*2)) + + // NOTE: The default Java serializer doesn't create zero-sized blocks. + // So, use Kryo + val c = new ShuffledRDD(b, new HashPartitioner(10)) + + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 4) + + val blockSizes = (0 until NUM_BLOCKS).flatMap { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + statuses.map(x => x._2) + } + val nonEmptyBlocks = blockSizes.filter(x => x > 0) + + // We should have at most 4 non-zero sized partitions + assert(nonEmptyBlocks.size <= 4) + } + } object ShuffleSuite {