Commit 0e5cc308 authored by Reynold Xin

Cleaned up BlockManager and BlockFetcherIterator from Shane's PR.

parent 8b794851
package spark.storage
private[spark] trait BlockFetchTracker {
def totalBlocks: Int
def numLocalBlocks: Int
def numRemoteBlocks: Int
def remoteFetchTime: Long
def fetchWaitTime: Long
def remoteBytesRead: Long
}
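
These counters are populated by the fetcher implementations below. As a minimal sketch, a caller holding any BlockFetchTracker (for example, a fully consumed fetcher iterator) could report them as follows; the helper name logFetchStats is hypothetical, not part of this commit:

// Hypothetical reporting helper, for illustration only.
def logFetchStats(tracker: BlockFetchTracker) {
  println("Fetched %d blocks (%d local, %d remote), %d remote bytes read".format(
    tracker.totalBlocks, tracker.numLocalBlocks, tracker.numRemoteBlocks,
    tracker.remoteBytesRead))
  println("Remote fetch time: %d ms, fetch wait time: %d ms".format(
    tracker.remoteFetchTime, tracker.fetchWaitTime))
}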
@@ -7,27 +7,36 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashSet
import scala.collection.mutable.Queue
import io.netty.buffer.ByteBuf
import spark.Logging
import spark.Utils
import spark.SparkException
import spark.network.BufferMessage
import spark.network.ConnectionManagerId
import spark.network.netty.ShuffleCopier
import spark.serializer.Serializer
/**
* A block fetcher iterator interface. There are two implementations:
*
* BasicBlockFetcherIterator: uses a custom-built NIO communication layer.
* NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer.
*
* Eventually we would like the two to converge and use a single NIO-based communication layer,
* but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores),
* NIO performs poorly, hence the need for the Netty OIO implementation.
*/
private[storage]
trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])]
with Logging with BlockFetchTracker {
def initialize()
}
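
A hedged usage sketch of this interface: each element is a (blockId, Option[Iterator[Any]]) pair, and None marks a failed fetch. Here iter is assumed to come from BlockManager.getMultiple (shown near the end of this commit), which returns it already initialized; process is a stub standing in for real work:

// Illustrative consumption loop only; error handling is simplified.
def process(blockId: String, value: Any) { /* application logic goes here */ }

for ((blockId, maybeValues) <- iter) {
  maybeValues match {
    case Some(values) => values.foreach(v => process(blockId, v))
    case None => println("Failed to fetch block " + blockId)
  }
}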
private[storage]
object BlockFetcherIterator {
// A request to fetch one or more blocks, complete with their sizes
@@ -45,8 +54,8 @@ object BlockFetcherIterator {
class BasicBlockFetcherIterator(
private val blockManager: BlockManager,
val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
serializer: Serializer)
extends BlockFetcherIterator {
import blockManager._
@@ -57,23 +66,24 @@ object BlockFetcherIterator {
if (blocksByAddress == null) {
throw new IllegalArgumentException("BlocksByAddress is null")
}
protected var _totalBlocks = blocksByAddress.map(_._2.size).sum
logDebug("Getting " + _totalBlocks + " blocks")
protected var startTime = System.currentTimeMillis
protected val localBlockIds = new ArrayBuffer[String]()
protected val remoteBlockIds = new HashSet[String]()
// A queue to hold our results.
protected val results = new LinkedBlockingQueue[FetchResult]
// Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
// the number of bytes in flight is limited to maxBytesInFlight
private val fetchRequests = new Queue[FetchRequest]
// Current bytes in flight from our requests
private var bytesInFlight = 0L
protected def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort))
val cmId = new ConnectionManagerId(req.address.host, req.address.port)
@@ -111,7 +121,7 @@ object BlockFetcherIterator {
}
}
protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val remoteRequests = new ArrayBuffer[FetchRequest]
@@ -148,14 +158,15 @@ object BlockFetcherIterator {
remoteRequests
}
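
The packing policy used here can be sketched on its own: non-empty remote (blockId, size) pairs are greedily grouped until the running size reaches minRequestSize, so several requests can be in flight without exceeding maxBytesInFlight. A simplified standalone version, for illustration only (minRequestSize is a parameter here; its derivation from maxBytesInFlight sits in the elided hunk above):

import scala.collection.mutable.ArrayBuffer

// Greedy request packing, mirroring the loop elided above.
def packRequests(blocks: Seq[(String, Long)], minRequestSize: Long): Seq[Seq[(String, Long)]] = {
  val requests = new ArrayBuffer[Seq[(String, Long)]]
  var current = new ArrayBuffer[(String, Long)]
  var currentSize = 0L
  for ((blockId, size) <- blocks if size > 0) {
    current += ((blockId, size))
    currentSize += size
    if (currentSize >= minRequestSize) {
      // Close out this request and start a new one.
      requests += current.toSeq
      current = new ArrayBuffer[(String, Long)]
      currentSize = 0L
    }
  }
  if (current.nonEmpty) requests += current.toSeq // leftover blocks form a final request
  requests
}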
protected def getLocalBlocks() {
// Get the local blocks while remote blocks are being fetched. Note that it's okay to do
// these all at once because they will just memory-map some files, so they won't consume
// any memory that might exceed our maxBytesInFlight
for (id <- localBlockIds) {
getLocal(id) match {
case Some(iter) => {
// Pass 0 as size since it's not in flight
results.put(new FetchResult(id, 0, () => iter))
logDebug("Got local block " + id)
}
case None => {
@@ -165,7 +176,7 @@ object BlockFetcherIterator {
}
}
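
The memory-map remark above can be made concrete: a local block read typically resolves to mapping a file segment, which costs address space rather than heap, so fetching all local blocks at once does not count against maxBytesInFlight. A hedged sketch of such a read; path, offset, and length are placeholders:

import java.io.RandomAccessFile
import java.nio.ByteBuffer
import java.nio.channels.FileChannel

// Illustrative only: map a file segment read-only, as a local disk read would.
def mapSegment(path: String, offset: Long, length: Long): ByteBuffer = {
  val channel = new RandomAccessFile(path, "r").getChannel
  try {
    channel.map(FileChannel.MapMode.READ_ONLY, offset, length)
  } finally {
    channel.close() // the mapping remains valid after the channel is closed
  }
}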
override def initialize() {
// Split local and remote blocks.
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
@@ -184,15 +195,14 @@ object BlockFetcherIterator {
startTime = System.currentTimeMillis
getLocalBlocks()
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
}
// An iterator that will read fetched blocks off the queue as they arrive.
@volatile protected var resultsGotten = 0
override def hasNext: Boolean = resultsGotten < _totalBlocks
override def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
val startFetchWait = System.currentTimeMillis()
val result = results.take()
@@ -206,74 +216,73 @@ object BlockFetcherIterator {
(result.blockId, if (result.failed) None else Some(result.deserialize()))
}
// Implementing BlockFetchTracker trait.
override def totalBlocks: Int = _totalBlocks
override def numLocalBlocks: Int = localBlockIds.size
override def numRemoteBlocks: Int = remoteBlockIds.size
override def remoteFetchTime: Long = _remoteFetchTime
override def fetchWaitTime: Long = _fetchWaitTime
override def remoteBytesRead: Long = _remoteBytesRead
}
// End of BasicBlockFetcherIterator
class NettyBlockFetcherIterator(
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
serializer: Serializer)
extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
import blockManager._
val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest]
private def startCopiers(numCopiers: Int): List[_ <: Thread] = {
  (for (i <- Range(0, numCopiers)) yield {
    val copier = new Thread {
      override def run() {
        try {
          while (!isInterrupted && !fetchRequestsSync.isEmpty) {
            sendRequest(fetchRequestsSync.take())
          }
        } catch {
          case x: InterruptedException => logInfo("Copier Interrupted")
          // case _ => throw new SparkException("Exception Throw in Shuffle Copier")
        }
      }
    }
    copier.start
    copier
  }).toList
}
// Keep this to interrupt the copier threads when necessary.
private def stopCopiers() {
for (copier <- copiers) {
copier.interrupt()
}
}
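
startCopiers and stopCopiers form a small reusable pattern: worker threads drain a blocking queue until it is empty or they are interrupted. A generic, hedged sketch of the same pattern, with all names as placeholders:

import java.util.concurrent.LinkedBlockingQueue

// Generic form of the copier pool used above.
def startWorkers[T](queue: LinkedBlockingQueue[T], numWorkers: Int)(handle: T => Unit): List[Thread] = {
  (for (i <- 0 until numWorkers) yield {
    val worker = new Thread {
      override def run() {
        try {
          while (!isInterrupted && !queue.isEmpty) {
            handle(queue.take())
          }
        } catch {
          case e: InterruptedException => // interrupted by stopWorkers; exit quietly
        }
      }
    }
    worker.start()
    worker
  }).toList
}

def stopWorkers(workers: List[Thread]) { workers.foreach(_.interrupt()) }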
override protected def sendRequest(req: FetchRequest) {
def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) {
val fetchResult = new FetchResult(blockId, blockSize,
() => dataDeserialize(blockId, blockData.nioBuffer, serializer))
results.put(fetchResult)
}
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.memoryBytesToString(req.size), req.address.host))
val cmId = new ConnectionManagerId(
req.address.host, System.getProperty("spark.shuffle.sender.port", "6653").toInt)
val cpier = new ShuffleCopier
cpier.getBlocks(cmId, req.blocks, putResult)
logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
}
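
The Netty path is configured entirely through system properties that appear in this diff. A hedged example of setting them before the BlockManager is created; the port and thread-count values shown are just the defaults from the surrounding code, not tuning advice:

// Illustrative configuration only.
System.setProperty("spark.shuffle.use.netty", "true")   // selects NettyBlockFetcherIterator in getMultiple
System.setProperty("spark.shuffle.sender.port", "6653") // port the ShuffleCopier connects to
System.setProperty("spark.shuffle.copier.threads", "6") // number of copier threads started in initialize()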
override protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val originalTotalBlocks = _totalBlocks
val remoteRequests = new ArrayBuffer[FetchRequest]
for ((address, blockInfos) <- blocksByAddress) {
if (address == blockManagerId) {
@@ -293,11 +302,11 @@ object BlockFetcherIterator {
if (size > 0) {
curBlocks += ((blockId, size))
curRequestSize += size
} else if (size == 0) {
// Here we change _totalBlocks, since zero-byte blocks are not fetched.
_totalBlocks -= 1
} else {
throw new BlockException(blockId, "Negative block size " + size)
}
if (curRequestSize >= minRequestSize) {
// Add this FetchRequest
@@ -312,13 +321,14 @@ object BlockFetcherIterator {
}
}
}
logInfo("Getting " + totalBlocks + " non 0-byte blocks out of " + originalTotalBlocks + " blocks")
logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
originalTotalBlocks + " blocks")
remoteRequests
}
private var copiers: List[_ <: Thread] = null
override def initialize() {
// Split local and remote blocks, and adjust _totalBlocks to count only non-zero-byte blocks.
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
@@ -327,7 +337,8 @@ object BlockFetcherIterator {
}
copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt)
logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime))
logInfo("Started " + fetchRequestsSync.size + " remote gets in " +
Utils.getUsedTimeMs(startTime))
// Get Local Blocks
startTime = System.currentTimeMillis
@@ -338,24 +349,12 @@ object BlockFetcherIterator {
override def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
val result = results.take()
// If all the results have been retrieved, shut down the copiers.
if (resultsGotten == _totalBlocks && copiers != null) {
stopCopiers()
}
(result.blockId, if (result.failed) None else Some(result.deserialize()))
}
}
// End of NettyBlockFetcherIterator
}
@@ -23,8 +23,7 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam
import sun.nio.ch.DirectBuffer
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
val master: BlockManagerMaster,
@@ -494,11 +493,16 @@ class BlockManager(
def getMultiple(
blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
: BlockFetcherIterator = {
val iter =
if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) {
new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
} else {
new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
}
iter.initialize()
iter
}
def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -942,8 +946,8 @@ class BlockManager(
}
}
private[spark] object BlockManager extends Logging {
val ID_GENERATOR = new IdGenerator