Commit 0e5cc308 authored by Reynold Xin
Cleaned up BlockManager and BlockFetcherIterator from Shane's PR.

parent 8b794851
BlockFetchTracker.scala

 package spark.storage

 private[spark] trait BlockFetchTracker {
   def totalBlocks: Int
   def numLocalBlocks: Int
   def numRemoteBlocks: Int
   def remoteFetchTime: Long
   def fetchWaitTime: Long
   def remoteBytesRead: Long
 }
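As a quick illustration of how these counters line up, here is a hypothetical helper (not part of this commit) that renders a finished fetch's statistics; it assumes it lives in the spark.storage package:

// Hypothetical helper (illustrative only): summarize a finished fetch.
private[spark] object BlockFetchStats {
  def summarize(t: BlockFetchTracker): String =
    "%d blocks (%d local, %d remote), %d remote bytes in %d ms (waited %d ms)".format(
      t.totalBlocks, t.numLocalBlocks, t.numRemoteBlocks,
      t.remoteBytesRead, t.remoteFetchTime, t.fetchWaitTime)
}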
BlockFetcherIterator.scala

@@ -7,27 +7,36 @@ import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashSet
 import scala.collection.mutable.Queue

+import io.netty.buffer.ByteBuf
+
 import spark.Logging
 import spark.Utils
 import spark.SparkException
 import spark.network.BufferMessage
 import spark.network.ConnectionManagerId
 import spark.network.netty.ShuffleCopier
 import spark.serializer.Serializer
-import io.netty.buffer.ByteBuf

+/**
+ * A block fetcher iterator interface. There are two implementations:
+ *
+ * BasicBlockFetcherIterator: uses a custom-built NIO communication layer.
+ * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer.
+ *
+ * Eventually we would like the two to converge and use a single NIO-based communication layer,
+ * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores),
+ * NIO would perform poorly and thus the need for the Netty OIO one.
+ */
+private[storage]
 trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])]
   with Logging with BlockFetchTracker {
   def initialize()
 }
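For orientation, a minimal caller-side sketch of this contract (illustrative code, not part of the commit; it assumes it runs in the spark.storage package): initialize() must be called once before iterating, and a None value marks a block that failed to fetch.

// Illustrative consumer of a BlockFetcherIterator (assumed helper, not in the commit).
def consume(it: BlockFetcherIterator) {
  it.initialize()  // split local/remote blocks and kick off the fetches
  for ((blockId, maybeValues) <- it) {
    maybeValues match {
      case Some(values) => values.foreach(v => println(blockId + " -> " + v))
      case None => println("Failed to fetch block " + blockId)
    }
  }
}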
+private[storage]
 object BlockFetcherIterator {

   // A request to fetch one or more blocks, complete with their sizes

@@ -45,8 +54,8 @@ object BlockFetcherIterator {
   class BasicBlockFetcherIterator(
       private val blockManager: BlockManager,
       val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
-      serializer: Serializer
-  ) extends BlockFetcherIterator {
+      serializer: Serializer)
+    extends BlockFetcherIterator {

     import blockManager._

@@ -57,23 +66,24 @@
     if (blocksByAddress == null) {
       throw new IllegalArgumentException("BlocksByAddress is null")
     }
-    var totalBlocks = blocksByAddress.map(_._2.size).sum
-    logDebug("Getting " + totalBlocks + " blocks")
-    var startTime = System.currentTimeMillis
-    val localBlockIds = new ArrayBuffer[String]()
-    val remoteBlockIds = new HashSet[String]()
+    protected var _totalBlocks = blocksByAddress.map(_._2.size).sum
+    logDebug("Getting " + _totalBlocks + " blocks")
+    protected var startTime = System.currentTimeMillis
+    protected val localBlockIds = new ArrayBuffer[String]()
+    protected val remoteBlockIds = new HashSet[String]()

     // A queue to hold our results.
-    val results = new LinkedBlockingQueue[FetchResult]
+    protected val results = new LinkedBlockingQueue[FetchResult]

     // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
     // the number of bytes in flight is limited to maxBytesInFlight
-    val fetchRequests = new Queue[FetchRequest]
+    private val fetchRequests = new Queue[FetchRequest]

     // Current bytes in flight from our requests
-    var bytesInFlight = 0L
+    private var bytesInFlight = 0L

-    def sendRequest(req: FetchRequest) {
+    protected def sendRequest(req: FetchRequest) {
       logDebug("Sending request for %d blocks (%s) from %s".format(
         req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort))
       val cmId = new ConnectionManagerId(req.address.host, req.address.port)

@@ -111,7 +121,7 @@
       }
     }

-    def splitLocalRemoteBlocks():ArrayBuffer[FetchRequest] = {
+    protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
       // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
       // at most maxBytesInFlight in order to limit the amount of data in flight.
       val remoteRequests = new ArrayBuffer[FetchRequest]

@@ -148,14 +158,15 @@
       remoteRequests
     }
-    def getLocalBlocks(){
+    protected def getLocalBlocks() {
       // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
       // these all at once because they will just memory-map some files, so they won't consume
       // any memory that might exceed our maxBytesInFlight
       for (id <- localBlockIds) {
         getLocal(id) match {
           case Some(iter) => {
-            results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
+            // Pass 0 as size since it's not in flight
+            results.put(new FetchResult(id, 0, () => iter))
             logDebug("Got local block " + id)
           }
           case None => {

@@ -165,7 +176,7 @@
       }
     }

-    override def initialize(){
+    override def initialize() {
       // Split local and remote blocks.
       val remoteRequests = splitLocalRemoteBlocks()
       // Add the remote requests into our queue in a random order

@@ -184,15 +195,14 @@
       startTime = System.currentTimeMillis
       getLocalBlocks()
       logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
     }

     // An iterator that reads fetched blocks off the queue as they arrive.
     @volatile protected var resultsGotten = 0

-    def hasNext: Boolean = resultsGotten < totalBlocks
+    override def hasNext: Boolean = resultsGotten < _totalBlocks

-    def next(): (String, Option[Iterator[Any]]) = {
+    override def next(): (String, Option[Iterator[Any]]) = {
       resultsGotten += 1
       val startFetchWait = System.currentTimeMillis()
       val result = results.take()
@@ -206,74 +216,73 @@
       (result.blockId, if (result.failed) None else Some(result.deserialize()))
     }

-    // methods to profile the block fetching
-    def numLocalBlocks = localBlockIds.size
-    def numRemoteBlocks = remoteBlockIds.size
-    def remoteFetchTime = _remoteFetchTime
-    def fetchWaitTime = _fetchWaitTime
-    def remoteBytesRead = _remoteBytesRead
+    // Implementing BlockFetchTracker trait.
+    override def totalBlocks: Int = _totalBlocks
+    override def numLocalBlocks: Int = localBlockIds.size
+    override def numRemoteBlocks: Int = remoteBlockIds.size
+    override def remoteFetchTime: Long = _remoteFetchTime
+    override def fetchWaitTime: Long = _fetchWaitTime
+    override def remoteBytesRead: Long = _remoteBytesRead
   }
+  // End of BasicBlockFetcherIterator
   class NettyBlockFetcherIterator(
       blockManager: BlockManager,
       blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
-      serializer: Serializer
-  ) extends BasicBlockFetcherIterator(blockManager,blocksByAddress,serializer) {
+      serializer: Serializer)
+    extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {

     import blockManager._

     val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest]

-    def putResult(blockId:String, blockSize:Long, blockData:ByteBuffer,
-        results : LinkedBlockingQueue[FetchResult]){
-      results.put(new FetchResult(
-        blockId, blockSize, () => dataDeserialize(blockId, blockData, serializer) ))
-    }
-
-    def startCopiers (numCopiers: Int): List [ _ <: Thread]= {
+    private def startCopiers(numCopiers: Int): List[_ <: Thread] = {
       (for ( i <- Range(0,numCopiers) ) yield {
         val copier = new Thread {
           override def run(){
             try {
               while(!isInterrupted && !fetchRequestsSync.isEmpty) {
                 sendRequest(fetchRequestsSync.take())
               }
             } catch {
               case x: InterruptedException => logInfo("Copier Interrupted")
               //case _ => throw new SparkException("Exception Throw in Shuffle Copier")
             }
           }
         }
         copier.start
         copier
       }).toList
     }
     // Keep this to interrupt the copier threads when necessary.
-    def stopCopiers(copiers : List[_ <: Thread]) {
+    private def stopCopiers() {
       for (copier <- copiers) {
         copier.interrupt()
       }
     }
-    override def sendRequest(req: FetchRequest) {
+    override protected def sendRequest(req: FetchRequest) {
+      def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) {
+        val fetchResult = new FetchResult(blockId, blockSize,
+          () => dataDeserialize(blockId, blockData.nioBuffer, serializer))
+        results.put(fetchResult)
+      }
+
       logDebug("Sending request for %d blocks (%s) from %s".format(
         req.blocks.size, Utils.memoryBytesToString(req.size), req.address.host))
-      val cmId = new ConnectionManagerId(req.address.host, System.getProperty("spark.shuffle.sender.port", "6653").toInt)
+      val cmId = new ConnectionManagerId(
+        req.address.host, System.getProperty("spark.shuffle.sender.port", "6653").toInt)
       val cpier = new ShuffleCopier
-      cpier.getBlocks(cmId,req.blocks,(blockId:String,blockSize:Long,blockData:ByteBuf) => putResult(blockId,blockSize,blockData.nioBuffer,results))
+      cpier.getBlocks(cmId, req.blocks, putResult)
       logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host)
     }
-    override def splitLocalRemoteBlocks() : ArrayBuffer[FetchRequest] = {
+    override protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
       // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
       // at most maxBytesInFlight in order to limit the amount of data in flight.
-      val originalTotalBlocks = totalBlocks;
+      val originalTotalBlocks = _totalBlocks
       val remoteRequests = new ArrayBuffer[FetchRequest]
       for ((address, blockInfos) <- blocksByAddress) {
         if (address == blockManagerId) {

@@ -293,11 +302,11 @@
           if (size > 0) {
             curBlocks += ((blockId, size))
             curRequestSize += size
-          } else if (size == 0){
-            //here we changes the totalBlocks
-            totalBlocks -= 1
+          } else if (size == 0) {
+            // Here we change the total number of blocks.
+            _totalBlocks -= 1
           } else {
-            throw new SparkException("Negative block size "+blockId)
+            throw new BlockException(blockId, "Negative block size " + size)
           }
           if (curRequestSize >= minRequestSize) {
             // Add this FetchRequest

@@ -312,13 +321,14 @@
           }
         }
       }
-      logInfo("Getting " + totalBlocks + " non 0-byte blocks out of " + originalTotalBlocks + " blocks")
+      logInfo("Getting " + _totalBlocks + " non-zero-byte blocks out of " +
+        originalTotalBlocks + " blocks")
       remoteRequests
     }
-    var copiers : List[_ <: Thread] = null
+    private var copiers: List[_ <: Thread] = null

-    override def initialize(){
+    override def initialize() {
       // Split local and remote blocks, and adjust totalBlocks to include only the non-0-byte blocks.
       val remoteRequests = splitLocalRemoteBlocks()
       // Add the remote requests into our queue in a random order

@@ -327,7 +337,8 @@
       }
       copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt)
-      logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime))
+      logInfo("Started " + fetchRequestsSync.size + " remote gets in " +
+        Utils.getUsedTimeMs(startTime))

       // Get Local Blocks
       startTime = System.currentTimeMillis
@@ -338,24 +349,12 @@
     override def next(): (String, Option[Iterator[Any]]) = {
       resultsGotten += 1
       val result = results.take()
-      // if all the results has been retrieved
-      // shutdown the copiers
-      if (resultsGotten == totalBlocks) {
-        if( copiers != null )
-          stopCopiers(copiers)
+      // If all the results have been retrieved, shut down the copiers.
+      if (resultsGotten == _totalBlocks && copiers != null) {
+        stopCopiers()
       }
       (result.blockId, if (result.failed) None else Some(result.deserialize()))
     }
   }
+  // End of NettyBlockFetcherIterator
-  def apply(t: String,
-      blockManager: BlockManager,
-      blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
-      serializer: Serializer): BlockFetcherIterator = {
-    val iter = if (t == "netty") { new NettyBlockFetcherIterator(blockManager,blocksByAddress, serializer) }
-    else { new BasicBlockFetcherIterator(blockManager,blocksByAddress, serializer) }
-    iter.initialize()
-    iter
-  }
 }
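For reference, the Netty code path above is driven entirely by system properties that appear in the diff. A small configuration sketch (the non-default value for spark.shuffle.use.netty is the only change; the other two show their defaults):

// Knobs read by the code above.
System.setProperty("spark.shuffle.use.netty", "true")    // pick NettyBlockFetcherIterator in BlockManager.getMultiple (below)
System.setProperty("spark.shuffle.sender.port", "6653")  // port the ShuffleCopier connects to
System.setProperty("spark.shuffle.copier.threads", "6")  // number of copier threads started in initialize()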
BlockManager.scala

@@ -23,8 +23,7 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam
 import sun.nio.ch.DirectBuffer

-private[spark]
-class BlockManager(
+private[spark] class BlockManager(
     executorId: String,
     actorSystem: ActorSystem,
     val master: BlockManagerMaster,
@@ -494,11 +493,16 @@
   def getMultiple(
     blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
     : BlockFetcherIterator = {
-    if(System.getProperty("spark.shuffle.use.netty", "false").toBoolean){
-      return BlockFetcherIterator("netty",this, blocksByAddress, serializer)
-    } else {
-      return BlockFetcherIterator("", this, blocksByAddress, serializer)
-    }
+    val iter =
+      if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) {
+        new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
+      } else {
+        new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
+      }
+
+    iter.initialize()
+    iter
   }
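Caller-side, the cleanup means getMultiple now returns an already-initialized iterator, replacing the string-keyed apply factory that this commit removes. A short usage sketch (blockManager, blocksByAddress, and serializer are assumed to be in scope):

// Illustrative caller code: fetch shuffle blocks via the Netty path.
System.setProperty("spark.shuffle.use.netty", "true")
val blockIter = blockManager.getMultiple(blocksByAddress, serializer)
for ((blockId, maybeValues) <- blockIter) {
  // maybeValues is None if the fetch of blockId failed
}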
   def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)

@@ -942,8 +946,8 @@
     }
   }

-private[spark]
-object BlockManager extends Logging {
+private[spark] object BlockManager extends Logging {
   val ID_GENERATOR = new IdGenerator