From 0e5cc30868bcf933f2980c4cfe29abc3d8fe5887 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@cs.berkeley.edu>
Date: Tue, 7 May 2013 18:18:24 -0700
Subject: [PATCH] Cleaned up BlockManager and BlockFetcherIterator from Shane's
 PR.

---
 .../spark/storage/BlockFetchTracker.scala     |  12 +-
 .../spark/storage/BlockFetcherIterator.scala  | 167 +++++++++---------
 .../scala/spark/storage/BlockManager.scala    |  22 ++-
 3 files changed, 102 insertions(+), 99 deletions(-)

diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
index 993aece1f7..0718156b1b 100644
--- a/core/src/main/scala/spark/storage/BlockFetchTracker.scala
+++ b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
@@ -1,10 +1,10 @@
 package spark.storage
 
 private[spark] trait BlockFetchTracker {
-    def totalBlocks : Int
-    def numLocalBlocks: Int
-    def numRemoteBlocks: Int
-    def remoteFetchTime : Long
-    def fetchWaitTime: Long
-    def remoteBytesRead : Long
+  def totalBlocks : Int
+  def numLocalBlocks: Int
+  def numRemoteBlocks: Int
+  def remoteFetchTime : Long
+  def fetchWaitTime: Long
+  def remoteBytesRead : Long
 }
diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
index 30990d9a38..43f835237c 100644
--- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
@@ -7,27 +7,36 @@ import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashSet
 import scala.collection.mutable.Queue
 
+import io.netty.buffer.ByteBuf
+
 import spark.Logging
 import spark.Utils
 import spark.SparkException
-
 import spark.network.BufferMessage
 import spark.network.ConnectionManagerId
 import spark.network.netty.ShuffleCopier
-
 import spark.serializer.Serializer
-import io.netty.buffer.ByteBuf
 
 
+/**
+ * A block fetcher iterator interface. There are two implementations:
+ *
+ * BasicBlockFetcherIterator: uses a custom-built NIO communication layer.
+ * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer.
+ *
+ * Eventually we would like the two to converge and use a single NIO-based communication layer,
+ * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores),
+ * NIO performs poorly, hence the need for the Netty OIO one.
+ */
+
+private[storage]
 trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])]
   with Logging with BlockFetchTracker {
-
   def initialize()
-
 }
 
 
-
+private[storage]
 object BlockFetcherIterator {
 
   // A request to fetch one or more blocks, complete with their sizes
@@ -45,8 +54,8 @@ object BlockFetcherIterator {
   class BasicBlockFetcherIterator(
       private val blockManager: BlockManager,
       val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
-      serializer: Serializer
-  ) extends BlockFetcherIterator {
+      serializer: Serializer)
+    extends BlockFetcherIterator {
 
     import blockManager._
 
@@ -57,23 +66,24 @@ object BlockFetcherIterator {
     if (blocksByAddress == null) {
       throw new IllegalArgumentException("BlocksByAddress is null")
     }
-    var totalBlocks = blocksByAddress.map(_._2.size).sum
-    logDebug("Getting " + totalBlocks + " blocks")
-    var startTime = System.currentTimeMillis
-    val localBlockIds = new ArrayBuffer[String]()
-    val remoteBlockIds = new HashSet[String]()
+
+    protected var _totalBlocks = blocksByAddress.map(_._2.size).sum
+    logDebug("Getting " + _totalBlocks + " blocks")
+    protected var startTime = System.currentTimeMillis
+    protected val localBlockIds = new ArrayBuffer[String]()
+    protected val remoteBlockIds = new HashSet[String]()
 
     // A queue to hold our results.
-    val results = new LinkedBlockingQueue[FetchResult]
+    protected val results = new LinkedBlockingQueue[FetchResult]
 
     // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
     // the number of bytes in flight is limited to maxBytesInFlight
-    val fetchRequests = new Queue[FetchRequest]
+    private val fetchRequests = new Queue[FetchRequest]
 
     // Current bytes in flight from our requests
-    var bytesInFlight = 0L
+    private var bytesInFlight = 0L
 
-    def sendRequest(req: FetchRequest) {
+    protected def sendRequest(req: FetchRequest) {
       logDebug("Sending request for %d blocks (%s) from %s".format(
         req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort))
       val cmId = new ConnectionManagerId(req.address.host, req.address.port)
@@ -111,7 +121,7 @@ object BlockFetcherIterator {
       }
     }
 
-    def splitLocalRemoteBlocks():ArrayBuffer[FetchRequest] = {
+    protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
       // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
       // at most maxBytesInFlight in order to limit the amount of data in flight.
       val remoteRequests = new ArrayBuffer[FetchRequest]
@@ -148,14 +158,15 @@ object BlockFetcherIterator {
       remoteRequests
     }
 
-    def getLocalBlocks(){
+    protected def getLocalBlocks() {
       // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
       // these all at once because they will just memory-map some files, so they won't consume
       // any memory that might exceed our maxBytesInFlight
       for (id <- localBlockIds) {
         getLocal(id) match {
           case Some(iter) => {
-            results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
+            // Pass 0 as size since it's not in flight
+            results.put(new FetchResult(id, 0, () => iter))
             logDebug("Got local block " + id)
           }
           case None => {
@@ -165,7 +176,7 @@ object BlockFetcherIterator {
       }
     }
 
-    override def initialize(){
+    override def initialize() {
       // Split local and remote blocks.
       val remoteRequests = splitLocalRemoteBlocks()
       // Add the remote requests into our queue in a random order
@@ -184,15 +195,14 @@ object BlockFetcherIterator {
       startTime = System.currentTimeMillis
       getLocalBlocks()
       logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
-
     }
 
     //an iterator that will read fetched blocks off the queue as they arrive.
     @volatile protected var resultsGotten = 0
 
-    def hasNext: Boolean = resultsGotten < totalBlocks
+    override def hasNext: Boolean = resultsGotten < _totalBlocks
 
-    def next(): (String, Option[Iterator[Any]]) = {
+    override def next(): (String, Option[Iterator[Any]]) = {
       resultsGotten += 1
       val startFetchWait = System.currentTimeMillis()
       val result = results.take()
@@ -206,74 +216,73 @@ object BlockFetcherIterator {
       (result.blockId, if (result.failed) None else Some(result.deserialize()))
     }
 
-
-    //methods to profile the block fetching
-    def numLocalBlocks = localBlockIds.size
-    def numRemoteBlocks = remoteBlockIds.size
-
-    def remoteFetchTime = _remoteFetchTime
-    def fetchWaitTime = _fetchWaitTime
-
-    def remoteBytesRead = _remoteBytesRead
-
+    // Implementing BlockFetchTracker trait.
+    override def totalBlocks: Int = _totalBlocks
+    override def numLocalBlocks: Int = localBlockIds.size
+    override def numRemoteBlocks: Int = remoteBlockIds.size
+    override def remoteFetchTime: Long = _remoteFetchTime
+    override def fetchWaitTime: Long = _fetchWaitTime
+    override def remoteBytesRead: Long = _remoteBytesRead
   }
-
+  // End of BasicBlockFetcherIterator
 
   class NettyBlockFetcherIterator(
       blockManager: BlockManager,
       blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
-      serializer: Serializer
-  ) extends BasicBlockFetcherIterator(blockManager,blocksByAddress,serializer) {
+      serializer: Serializer)
+    extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
 
     import blockManager._
 
     val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest]
 
-    def putResult(blockId:String, blockSize:Long, blockData:ByteBuffer,
-                  results : LinkedBlockingQueue[FetchResult]){
-       results.put(new FetchResult(
-          blockId, blockSize, () => dataDeserialize(blockId, blockData, serializer) ))
-    }
-
-    def startCopiers (numCopiers: Int): List [ _ <: Thread]= {
+    private def startCopiers(numCopiers: Int): List[_ <: Thread] = {
       (for ( i <- Range(0,numCopiers) ) yield {
-          val copier = new Thread {
-             override def run(){
-              try {
-               while(!isInterrupted && !fetchRequestsSync.isEmpty) {
+        val copier = new Thread {
+          override def run(){
+            try {
+              while(!isInterrupted && !fetchRequestsSync.isEmpty) {
                 sendRequest(fetchRequestsSync.take())
-               }
-              } catch {
-                case x: InterruptedException => logInfo("Copier Interrupted")
-                //case _ => throw new SparkException("Exception Throw in Shuffle Copier")
               }
-             }
+            } catch {
+              case x: InterruptedException => logInfo("Copier Interrupted")
+              //case _ => throw new SparkException("Exception Throw in Shuffle Copier")
+            }
           }
-          copier.start
-          copier
+        }
+        copier.start
+        copier
       }).toList
     }
 
     //keep this to interrupt the threads when necessary
-    def stopCopiers(copiers : List[_ <: Thread]) {
+    private def stopCopiers() {
       for (copier <- copiers) {
         copier.interrupt()
       }
     }
 
-    override def sendRequest(req: FetchRequest) {
+    override protected def sendRequest(req: FetchRequest) {
+
+      def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) {
+        val fetchResult = new FetchResult(blockId, blockSize,
+          () => dataDeserialize(blockId, blockData.nioBuffer, serializer))
+        results.put(fetchResult)
+      }
+
       logDebug("Sending request for %d blocks (%s) from %s".format(
         req.blocks.size, Utils.memoryBytesToString(req.size), req.address.host))
-      val cmId = new ConnectionManagerId(req.address.host, System.getProperty("spark.shuffle.sender.port", "6653").toInt)
+      val cmId = new ConnectionManagerId(
+        req.address.host, System.getProperty("spark.shuffle.sender.port", "6653").toInt)
       val cpier = new ShuffleCopier
-      cpier.getBlocks(cmId,req.blocks,(blockId:String,blockSize:Long,blockData:ByteBuf) => putResult(blockId,blockSize,blockData.nioBuffer,results))
+      cpier.getBlocks(cmId, req.blocks, putResult)
       logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
     }
 
-    override def splitLocalRemoteBlocks() : ArrayBuffer[FetchRequest] = {
+    override protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
       // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
       // at most maxBytesInFlight in order to limit the amount of data in flight.
-      val originalTotalBlocks = totalBlocks;
+      val originalTotalBlocks = _totalBlocks;
       val remoteRequests = new ArrayBuffer[FetchRequest]
       for ((address, blockInfos) <- blocksByAddress) {
         if (address == blockManagerId) {
@@ -293,11 +302,11 @@ object BlockFetcherIterator {
             if (size > 0) {
               curBlocks += ((blockId, size))
               curRequestSize += size
-            } else if (size == 0){
+            } else if (size == 0) {
               //here we changes the totalBlocks
-              totalBlocks -= 1
+              _totalBlocks -= 1
             } else {
-              throw new SparkException("Negative block size "+blockId)
+              throw new BlockException(blockId, "Negative block size " + size)
             }
             if (curRequestSize >= minRequestSize) {
               // Add this FetchRequest
@@ -312,13 +321,14 @@ object BlockFetcherIterator {
           }
         }
       }
-      logInfo("Getting " + totalBlocks + " non 0-byte blocks out of " + originalTotalBlocks + " blocks")
+      logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
+        originalTotalBlocks + " blocks")
       remoteRequests
     }
 
-    var copiers : List[_ <: Thread] = null
+    private var copiers: List[_ <: Thread] = null
 
-    override def initialize(){
+    override def initialize() {
       // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks
       val remoteRequests = splitLocalRemoteBlocks()
       // Add the remote requests into our queue in a random order
@@ -327,7 +337,8 @@ object BlockFetcherIterator {
       }
 
       copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt)
-      logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime))
+      logInfo("Started " + fetchRequestsSync.size + " remote gets in " +
+        Utils.getUsedTimeMs(startTime))
 
       // Get Local Blocks
       startTime = System.currentTimeMillis
@@ -338,24 +349,12 @@ object BlockFetcherIterator {
     override def next(): (String, Option[Iterator[Any]]) = {
       resultsGotten += 1
       val result = results.take()
-      // if all the results has been retrieved
-      // shutdown the copiers
-      if (resultsGotten == totalBlocks) {
-        if( copiers != null )
-          stopCopiers(copiers)
+      // If all the results have been retrieved, shut down the copiers.
+      if (resultsGotten == _totalBlocks && copiers != null) {
+        stopCopiers()
       }
       (result.blockId, if (result.failed) None else Some(result.deserialize()))
     }
   }
-
-  def apply(t: String,
-            blockManager: BlockManager,
-            blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
-            serializer: Serializer):  BlockFetcherIterator = {
-    val iter = if (t == "netty") { new NettyBlockFetcherIterator(blockManager,blocksByAddress, serializer) }
-              else { new BasicBlockFetcherIterator(blockManager,blocksByAddress, serializer) }
-    iter.initialize()
-    iter
-  }
+  // End of NettyBlockFetcherIterator
 }
-
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index a189c1a025..e0dec3a8bb 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -23,8 +23,7 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam
 import sun.nio.ch.DirectBuffer
 
 
-private[spark]
-class BlockManager(
+private[spark] class BlockManager(
     executorId: String,
     actorSystem: ActorSystem,
     val master: BlockManagerMaster,
@@ -494,11 +493,16 @@ class BlockManager(
   def getMultiple(
     blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
       : BlockFetcherIterator = {
-    if(System.getProperty("spark.shuffle.use.netty", "false").toBoolean){
-      return BlockFetcherIterator("netty",this, blocksByAddress, serializer)
-    } else {
-      return BlockFetcherIterator("", this, blocksByAddress, serializer)
-    }
+
+    val iter =
+      if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) {
+        new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
+      } else {
+        new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
+      }
+
+    iter.initialize()
+    iter
   }
 
   def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -942,8 +946,8 @@ class BlockManager(
   }
 }
 
-private[spark]
-object BlockManager extends Logging {
+
+private[spark] object BlockManager extends Logging {
 
   val ID_GENERATOR = new IdGenerator
 
-- 
GitLab