From 0e5cc30868bcf933f2980c4cfe29abc3d8fe5887 Mon Sep 17 00:00:00 2001 From: Reynold Xin <rxin@cs.berkeley.edu> Date: Tue, 7 May 2013 18:18:24 -0700 Subject: [PATCH] Cleaned up BlockManager and BlockFetcherIterator from Shane's PR. --- .../spark/storage/BlockFetchTracker.scala | 12 +- .../spark/storage/BlockFetcherIterator.scala | 167 +++++++++--------- .../scala/spark/storage/BlockManager.scala | 22 ++- 3 files changed, 102 insertions(+), 99 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala index 993aece1f7..0718156b1b 100644 --- a/core/src/main/scala/spark/storage/BlockFetchTracker.scala +++ b/core/src/main/scala/spark/storage/BlockFetchTracker.scala @@ -1,10 +1,10 @@ package spark.storage private[spark] trait BlockFetchTracker { - def totalBlocks : Int - def numLocalBlocks: Int - def numRemoteBlocks: Int - def remoteFetchTime : Long - def fetchWaitTime: Long - def remoteBytesRead : Long + def totalBlocks : Int + def numLocalBlocks: Int + def numRemoteBlocks: Int + def remoteFetchTime : Long + def fetchWaitTime: Long + def remoteBytesRead : Long } diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index 30990d9a38..43f835237c 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -7,27 +7,36 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue +import io.netty.buffer.ByteBuf + import spark.Logging import spark.Utils import spark.SparkException - import spark.network.BufferMessage import spark.network.ConnectionManagerId import spark.network.netty.ShuffleCopier - import spark.serializer.Serializer -import io.netty.buffer.ByteBuf +/** + * A block fetcher iterator interface. There are two implementations: + * + * BasicBlockFetcherIterator: uses a custom-built NIO communication layer. + * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer. + * + * Eventually we would like the two to converge and use a single NIO-based communication layer, + * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores), + * NIO would perform poorly and thus the need for the Netty OIO one. 
+ */ + +private[storage] trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker { - def initialize() - } - +private[storage] object BlockFetcherIterator { // A request to fetch one or more blocks, complete with their sizes @@ -45,8 +54,8 @@ object BlockFetcherIterator { class BasicBlockFetcherIterator( private val blockManager: BlockManager, val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], - serializer: Serializer - ) extends BlockFetcherIterator { + serializer: Serializer) + extends BlockFetcherIterator { import blockManager._ @@ -57,23 +66,24 @@ object BlockFetcherIterator { if (blocksByAddress == null) { throw new IllegalArgumentException("BlocksByAddress is null") } - var totalBlocks = blocksByAddress.map(_._2.size).sum - logDebug("Getting " + totalBlocks + " blocks") - var startTime = System.currentTimeMillis - val localBlockIds = new ArrayBuffer[String]() - val remoteBlockIds = new HashSet[String]() + + protected var _totalBlocks = blocksByAddress.map(_._2.size).sum + logDebug("Getting " + _totalBlocks + " blocks") + protected var startTime = System.currentTimeMillis + protected val localBlockIds = new ArrayBuffer[String]() + protected val remoteBlockIds = new HashSet[String]() // A queue to hold our results. - val results = new LinkedBlockingQueue[FetchResult] + protected val results = new LinkedBlockingQueue[FetchResult] // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that // the number of bytes in flight is limited to maxBytesInFlight - val fetchRequests = new Queue[FetchRequest] + private val fetchRequests = new Queue[FetchRequest] // Current bytes in flight from our requests - var bytesInFlight = 0L + private var bytesInFlight = 0L - def sendRequest(req: FetchRequest) { + protected def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort)) val cmId = new ConnectionManagerId(req.address.host, req.address.port) @@ -111,7 +121,7 @@ object BlockFetcherIterator { } } - def splitLocalRemoteBlocks():ArrayBuffer[FetchRequest] = { + protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] @@ -148,14 +158,15 @@ object BlockFetcherIterator { remoteRequests } - def getLocalBlocks(){ + protected def getLocalBlocks() { // Get the local blocks while remote blocks are being fetched. Note that it's okay to do // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight for (id <- localBlockIds) { getLocal(id) match { case Some(iter) => { - results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight + // Pass 0 as size since it's not in flight + results.put(new FetchResult(id, 0, () => iter)) logDebug("Got local block " + id) } case None => { @@ -165,7 +176,7 @@ object BlockFetcherIterator { } } - override def initialize(){ + override def initialize() { // Split local and remote blocks. 
val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order @@ -184,15 +195,14 @@ object BlockFetcherIterator { startTime = System.currentTimeMillis getLocalBlocks() logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } //an iterator that will read fetched blocks off the queue as they arrive. @volatile protected var resultsGotten = 0 - def hasNext: Boolean = resultsGotten < totalBlocks + override def hasNext: Boolean = resultsGotten < _totalBlocks - def next(): (String, Option[Iterator[Any]]) = { + override def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 val startFetchWait = System.currentTimeMillis() val result = results.take() @@ -206,74 +216,73 @@ object BlockFetcherIterator { (result.blockId, if (result.failed) None else Some(result.deserialize())) } - - //methods to profile the block fetching - def numLocalBlocks = localBlockIds.size - def numRemoteBlocks = remoteBlockIds.size - - def remoteFetchTime = _remoteFetchTime - def fetchWaitTime = _fetchWaitTime - - def remoteBytesRead = _remoteBytesRead - + // Implementing BlockFetchTracker trait. + override def totalBlocks: Int = _totalBlocks + override def numLocalBlocks: Int = localBlockIds.size + override def numRemoteBlocks: Int = remoteBlockIds.size + override def remoteFetchTime: Long = _remoteFetchTime + override def fetchWaitTime: Long = _fetchWaitTime + override def remoteBytesRead: Long = _remoteBytesRead } - + // End of BasicBlockFetcherIterator class NettyBlockFetcherIterator( blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], - serializer: Serializer - ) extends BasicBlockFetcherIterator(blockManager,blocksByAddress,serializer) { + serializer: Serializer) + extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { import blockManager._ val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] - def putResult(blockId:String, blockSize:Long, blockData:ByteBuffer, - results : LinkedBlockingQueue[FetchResult]){ - results.put(new FetchResult( - blockId, blockSize, () => dataDeserialize(blockId, blockData, serializer) )) - } - - def startCopiers (numCopiers: Int): List [ _ <: Thread]= { + private def startCopiers(numCopiers: Int): List[_ <: Thread] = { (for ( i <- Range(0,numCopiers) ) yield { - val copier = new Thread { - override def run(){ - try { - while(!isInterrupted && !fetchRequestsSync.isEmpty) { + val copier = new Thread { + override def run(){ + try { + while(!isInterrupted && !fetchRequestsSync.isEmpty) { sendRequest(fetchRequestsSync.take()) - } - } catch { - case x: InterruptedException => logInfo("Copier Interrupted") - //case _ => throw new SparkException("Exception Throw in Shuffle Copier") } - } + } catch { + case x: InterruptedException => logInfo("Copier Interrupted") + //case _ => throw new SparkException("Exception Throw in Shuffle Copier") + } } - copier.start - copier + } + copier.start + copier }).toList } //keep this to interrupt the threads when necessary - def stopCopiers(copiers : List[_ <: Thread]) { + private def stopCopiers() { for (copier <- copiers) { copier.interrupt() } } - override def sendRequest(req: FetchRequest) { + override protected def sendRequest(req: FetchRequest) { + + def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) { + val fetchResult = new FetchResult(blockId, blockSize, + () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) + results.put(fetchResult) + } + logDebug("Sending request for %d blocks (%s) 
from %s".format( req.blocks.size, Utils.memoryBytesToString(req.size), req.address.host)) - val cmId = new ConnectionManagerId(req.address.host, System.getProperty("spark.shuffle.sender.port", "6653").toInt) + val cmId = new ConnectionManagerId( + req.address.host, System.getProperty("spark.shuffle.sender.port", "6653").toInt) val cpier = new ShuffleCopier - cpier.getBlocks(cmId,req.blocks,(blockId:String,blockSize:Long,blockData:ByteBuf) => putResult(blockId,blockSize,blockData.nioBuffer,results)) + cpier.getBlocks(cmId, req.blocks, putResult) logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) } - override def splitLocalRemoteBlocks() : ArrayBuffer[FetchRequest] = { + override protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. - val originalTotalBlocks = totalBlocks; + val originalTotalBlocks = _totalBlocks; val remoteRequests = new ArrayBuffer[FetchRequest] for ((address, blockInfos) <- blocksByAddress) { if (address == blockManagerId) { @@ -293,11 +302,11 @@ object BlockFetcherIterator { if (size > 0) { curBlocks += ((blockId, size)) curRequestSize += size - } else if (size == 0){ + } else if (size == 0) { //here we changes the totalBlocks - totalBlocks -= 1 + _totalBlocks -= 1 } else { - throw new SparkException("Negative block size "+blockId) + throw new BlockException(blockId, "Negative block size " + size) } if (curRequestSize >= minRequestSize) { // Add this FetchRequest @@ -312,13 +321,14 @@ object BlockFetcherIterator { } } } - logInfo("Getting " + totalBlocks + " non 0-byte blocks out of " + originalTotalBlocks + " blocks") + logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " + + originalTotalBlocks + " blocks") remoteRequests } - var copiers : List[_ <: Thread] = null + private var copiers: List[_ <: Thread] = null - override def initialize(){ + override def initialize() { // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order @@ -327,7 +337,8 @@ object BlockFetcherIterator { } copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt) - logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime)) + logInfo("Started " + fetchRequestsSync.size + " remote gets in " + + Utils.getUsedTimeMs(startTime)) // Get Local Blocks startTime = System.currentTimeMillis @@ -338,24 +349,12 @@ object BlockFetcherIterator { override def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 val result = results.take() - // if all the results has been retrieved - // shutdown the copiers - if (resultsGotten == totalBlocks) { - if( copiers != null ) - stopCopiers(copiers) + // if all the results has been retrieved, shutdown the copiers + if (resultsGotten == _totalBlocks && copiers != null) { + stopCopiers() } (result.blockId, if (result.failed) None else Some(result.deserialize())) } } - - def apply(t: String, - blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], - serializer: Serializer): BlockFetcherIterator = { - val iter = if (t == "netty") { new NettyBlockFetcherIterator(blockManager,blocksByAddress, serializer) } - else { new BasicBlockFetcherIterator(blockManager,blocksByAddress, serializer) } - 
iter.initialize() - iter - } + // End of NettyBlockFetcherIterator } - diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index a189c1a025..e0dec3a8bb 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -23,8 +23,7 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam import sun.nio.ch.DirectBuffer -private[spark] -class BlockManager( +private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, val master: BlockManagerMaster, @@ -494,11 +493,16 @@ class BlockManager( def getMultiple( blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer) : BlockFetcherIterator = { - if(System.getProperty("spark.shuffle.use.netty", "false").toBoolean){ - return BlockFetcherIterator("netty",this, blocksByAddress, serializer) - } else { - return BlockFetcherIterator("", this, blocksByAddress, serializer) - } + + val iter = + if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) { + new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) + } else { + new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) + } + + iter.initialize() + iter } def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) @@ -942,8 +946,8 @@ class BlockManager( } } -private[spark] -object BlockManager extends Logging { + +private[spark] object BlockManager extends Logging { val ID_GENERATOR = new IdGenerator -- GitLab
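
Usage note (a sketch, not part of the patch): the refactored BlockManager.getMultiple now picks the fetcher implementation and calls initialize() itself, so callers only iterate the result. The snippet below assumes a BlockManager instance, a blocksByAddress sequence, and a Serializer are already in scope inside a class mixing in spark.Logging; the process() callback is hypothetical, while the property name spark.shuffle.use.netty and the iterator's (String, Option[Iterator[Any]]) element type come from the diff above.

    // Opt into the Netty (OIO) fetcher; the default remains the NIO-based BasicBlockFetcherIterator.
    System.setProperty("spark.shuffle.use.netty", "true")

    // getMultiple constructs the right BlockFetcherIterator and calls initialize() internally.
    val iter: BlockFetcherIterator = blockManager.getMultiple(blocksByAddress, serializer)

    // Each element is (blockId, Some(records)) on success, or (blockId, None) if the fetch failed.
    for ((blockId, blockOption) <- iter) {
      blockOption match {
        case Some(records) => records.foreach(process)               // process() is a hypothetical callback
        case None          => logWarning("Failed to fetch block " + blockId)
      }
    }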