From ac480fd977e0de97bcfe646e39feadbd239c1c29 Mon Sep 17 00:00:00 2001
From: Shivaram Venkataraman <shivaram@eecs.berkeley.edu>
Date: Thu, 6 Jun 2013 16:34:27 -0700
Subject: [PATCH] Clean up variables and counters in BlockFetcherIterator

---
 .../spark/storage/BlockFetcherIterator.scala  | 54 +++++++++++--------
 1 file changed, 31 insertions(+), 23 deletions(-)

diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
index 843069239c..bb78207c9f 100644
--- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
@@ -67,12 +67,20 @@ object BlockFetcherIterator {
       throw new IllegalArgumentException("BlocksByAddress is null")
     }
 
-    protected var _totalBlocks = blocksByAddress.map(_._2.size).sum
-    logDebug("Getting " + _totalBlocks + " blocks")
+    // Total number blocks fetched (local + remote). Also number of FetchResults expected
+    protected var _numBlocksToFetch = 0
+
     protected var startTime = System.currentTimeMillis
-    protected val localBlockIds = new ArrayBuffer[String]()
-    protected val localNonZeroBlocks = new ArrayBuffer[String]()
-    protected val remoteBlockIds = new HashSet[String]()
+
+    // This represents the number of local blocks, also counting zero-sized blocks
+    private var numLocal = 0
+    // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
+    protected val localBlocksToFetch = new ArrayBuffer[String]()
+
+    // This represents the number of remote blocks, also counting zero-sized blocks
+    private var numRemote = 0
+    // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
+    protected val remoteBlocksToFetch = new HashSet[String]()
 
     // A queue to hold our results.
     protected val results = new LinkedBlockingQueue[FetchResult]
@@ -125,15 +133,15 @@ object BlockFetcherIterator {
     protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
       // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
       // at most maxBytesInFlight in order to limit the amount of data in flight.
-      val originalTotalBlocks = _totalBlocks
       val remoteRequests = new ArrayBuffer[FetchRequest]
       for ((address, blockInfos) <- blocksByAddress) {
         if (address == blockManagerId) {
-          localBlockIds ++= blockInfos.map(_._1)
-          localNonZeroBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
-          _totalBlocks -= (localBlockIds.size - localNonZeroBlocks.size)
+          numLocal = blockInfos.size
+          // Filter out zero-sized blocks
+          localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
+          _numBlocksToFetch += localBlocksToFetch.size
         } else {
-          remoteBlockIds ++= blockInfos.map(_._1)
+          numRemote += blockInfos.size
           // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
           // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
           // nodes, rather than blocking on reading output from one node.
@@ -147,10 +155,10 @@ object BlockFetcherIterator {
             // Skip empty blocks
             if (size > 0) {
               curBlocks += ((blockId, size))
+              remoteBlocksToFetch += blockId
+              _numBlocksToFetch += 1
               curRequestSize += size
-            } else if (size == 0) {
-              _totalBlocks -= 1
-            } else {
+            } else if (size < 0) {
               throw new BlockException(blockId, "Negative block size " + size)
             }
             if (curRequestSize >= minRequestSize) {
@@ -166,8 +174,8 @@ object BlockFetcherIterator {
           }
         }
       }
-      logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
-        originalTotalBlocks + " blocks")
+      logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " +
+        totalBlocks + " blocks")
       remoteRequests
     }
 
@@ -175,7 +183,7 @@ object BlockFetcherIterator {
       // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
       // these all at once because they will just memory-map some files, so they won't consume
       // any memory that might exceed our maxBytesInFlight
-      for (id <- localNonZeroBlocks) {
+      for (id <- localBlocksToFetch) {
         getLocalFromDisk(id, serializer) match {
           case Some(iter) => {
             // Pass 0 as size since it's not in flight
@@ -201,7 +209,7 @@ object BlockFetcherIterator {
         sendRequest(fetchRequests.dequeue())
       }
 
-      val numGets = remoteBlockIds.size - fetchRequests.size
+      val numGets = remoteRequests.size - fetchRequests.size
       logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
 
       // Get Local Blocks
@@ -213,7 +221,7 @@ object BlockFetcherIterator {
     //an iterator that will read fetched blocks off the queue as they arrive.
     @volatile protected var resultsGotten = 0
 
-    override def hasNext: Boolean = resultsGotten < _totalBlocks
+    override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
 
     override def next(): (String, Option[Iterator[Any]]) = {
       resultsGotten += 1
@@ -230,9 +238,9 @@ object BlockFetcherIterator {
     }
 
     // Implementing BlockFetchTracker trait.
-    override def totalBlocks: Int = _totalBlocks
-    override def numLocalBlocks: Int = localBlockIds.size
-    override def numRemoteBlocks: Int = remoteBlockIds.size
+    override def totalBlocks: Int = numLocal + numRemote
+    override def numLocalBlocks: Int = numLocal
+    override def numRemoteBlocks: Int = numRemote
     override def remoteFetchTime: Long = _remoteFetchTime
     override def fetchWaitTime: Long = _fetchWaitTime
     override def remoteBytesRead: Long = _remoteBytesRead
@@ -294,7 +302,7 @@ object BlockFetcherIterator {
     private var copiers: List[_ <: Thread] = null
 
     override def initialize() {
-      // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks
+      // Split Local Remote Blocks and set numBlocksToFetch
       val remoteRequests = splitLocalRemoteBlocks()
       // Add the remote requests into our queue in a random order
       for (request <- Utils.randomize(remoteRequests)) {
@@ -316,7 +324,7 @@ object BlockFetcherIterator {
       val result = results.take()
       // if all the results has been retrieved, shutdown the copiers
       // NO need to stop the copiers if we got all the blocks ?
-      // if (resultsGotten == _totalBlocks && copiers != null) {
+      // if (resultsGotten == _numBlocksToFetch && copiers != null) {
       //   stopCopiers()
       // }
       (result.blockId, if (result.failed) None else Some(result.deserialize()))
-- 
GitLab