From 1d62f8aca82601506c44b6fd852f4faf3602d7e2 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Sat, 27 May 2017 10:57:43 +0800
Subject: [PATCH] [SPARK-19659][CORE][FOLLOW-UP] Fetch big blocks to disk when
 shuffle-read

## What changes were proposed in this pull request?

This PR includes some minor improvements to the comments and tests in https://github.com/apache/spark/pull/16989.
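
To make the documented behavior concrete, here is a minimal, self-contained sketch of the size-threshold decision. Only the name `maxReqSizeShuffleToMem` and the test value 200 come from this patch; the object and method names are illustrative, not the actual Spark internals:

```scala
object FetchDestinationSketch extends App {
  // Assumed threshold, matching the value hard-coded in the updated test.
  val maxReqSizeShuffleToMem = 200L

  // A request larger than the threshold is fetched straight to disk;
  // anything else stays in memory.
  def fetchDestination(reqSize: Long): String =
    if (reqSize > maxReqSizeShuffleToMem) "disk" else "memory"

  assert(fetchDestination(100L) == "memory") // mirrors blocksByAddress1 below
  assert(fetchDestination(300L) == "disk")   // mirrors blocksByAddress2 below
}
```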

## How was this patch tested?

N/A

Author: Wenchen Fan <wenchen@databricks.com>

Closes #18117 from cloud-fan/follow.
---
 .../storage/ShuffleBlockFetcherIterator.scala |  9 ++--
 .../ShuffleBlockFetcherIteratorSuite.scala    | 50 ++++++++++---------
 2 files changed, 31 insertions(+), 28 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index ee35060926..bded3a1e4e 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -214,11 +214,12 @@ final class ShuffleBlockFetcherIterator(
       }
     }
 
-    // Shuffle remote blocks to disk when the request is too large.
-    // TODO: Encryption and compression should be considered.
+    // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
+    // already encrypted and compressed over the wire (w.r.t. the related configs), we can just
+    // fetch the data and write it to a file directly.
     if (req.size > maxReqSizeShuffleToMem) {
-      val shuffleFiles = blockIds.map {
-        bId => blockManager.diskBlockManager.createTempLocalBlock()._2
+      val shuffleFiles = blockIds.map { _ =>
+        blockManager.diskBlockManager.createTempLocalBlock()._2
       }.toArray
       shuffleFilesSet ++= shuffleFiles
       shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
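
As an aside, the comment above relies on the shuffle bytes already being compressed and encrypted on the wire, so they can be written to a temp file verbatim. A self-contained sketch of that idea, using plain JVM I/O rather than Spark's DiskBlockManager (the byte array and the temp-file naming are assumptions):

```scala
import java.io.{File, FileOutputStream}
import java.util.UUID

object FetchToFileSketch extends App {
  // Stand-in for a fetched block; per the comment in the patch, the bytes
  // need no transformation before being written.
  val fetchedBytes: Array[Byte] = Array.fill(300)(1.toByte)

  // Roughly what createTempLocalBlock() provides: a unique local temp file.
  val tmpFile = File.createTempFile(s"temp_local_${UUID.randomUUID()}", ".data")
  val out = new FileOutputStream(tmpFile)
  try out.write(fetchedBytes) finally out.close()

  assert(tmpFile.length() == fetchedBytes.length)
  tmpFile.delete()
}
```
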
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 1f813a909f..559b3faab8 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle.BlockFetchingListener
 import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.util.Utils
 
 
 class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
@@ -420,9 +421,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     doReturn(localBmId).when(blockManager).blockManagerId
 
     val diskBlockManager = mock(classOf[DiskBlockManager])
+    val tmpDir = Utils.createTempDir()
     doReturn{
-      var blockId = new TempLocalBlockId(UUID.randomUUID())
-      (blockId, new File(blockId.name))
+      val blockId = TempLocalBlockId(UUID.randomUUID())
+      (blockId, new File(tmpDir, blockId.name))
     }.when(diskBlockManager).createTempLocalBlock()
     doReturn(diskBlockManager).when(blockManager).diskBlockManager
 
@@ -443,34 +445,34 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
         }
       })
 
+    def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = {
+      // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the
+      // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks
+      // are issued. `maxReqSizeShuffleToMem` is hard-coded to 200 here.
+      new ShuffleBlockFetcherIterator(
+        TaskContext.empty(),
+        transfer,
+        blockManager,
+        blocksByAddress,
+        (_, in) => in,
+        maxBytesInFlight = Int.MaxValue,
+        maxReqsInFlight = Int.MaxValue,
+        maxReqSizeShuffleToMem = 200,
+        detectCorrupt = true)
+    }
+
     val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq))
-    // Set maxReqSizeShuffleToMem to be 200.
-    val iterator1 = new ShuffleBlockFetcherIterator(
-      TaskContext.empty(),
-      transfer,
-      blockManager,
-      blocksByAddress1,
-      (_, in) => in,
-      Int.MaxValue,
-      Int.MaxValue,
-      200,
-      true)
+    fetchShuffleBlock(blocksByAddress1)
+    // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so the shuffle
+    // blocks are not fetched to disk.
     assert(shuffleFiles === null)
 
     val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq))
-    // Set maxReqSizeShuffleToMem to be 200.
-    val iterator2 = new ShuffleBlockFetcherIterator(
-      TaskContext.empty(),
-      transfer,
-      blockManager,
-      blocksByAddress2,
-      (_, in) => in,
-      Int.MaxValue,
-      Int.MaxValue,
-      200,
-      true)
+    fetchShuffleBlock(blocksByAddress2)
+    // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so the shuffle
+    // blocks are fetched to disk.
     assert(shuffleFiles != null)
   }
 }
-- 
GitLab