From 8ff4417f70198ba2d848157f9da4e1e7e18f4fca Mon Sep 17 00:00:00 2001
From: Kousuke Saruta <sarutak@oss.nttdata.co.jp>
Date: Fri, 1 Aug 2014 00:01:30 -0700
Subject: [PATCH] [SPARK-2670] FetchFailedException should be thrown when local
 fetch has failed

Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp>

Closes #1578 from sarutak/SPARK-2670 and squashes the following commits:

85c8938 [Kousuke Saruta] Removed useless results.put for fail fast
e8713cc [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2670
d353984 [Kousuke Saruta] Refined assertion messages in BlockFetcherIteratorSuite.scala
03bcb02 [Kousuke Saruta] Merge branch 'SPARK-2670' of github.com:sarutak/spark into SPARK-2670
5d05855 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2670
4fca130 [Kousuke Saruta] Added test cases for BasicBlockFetcherIterator
b7b8250 [Kousuke Saruta] Modified BasicBlockFetchIterator to fail fast when local fetch error has been occurred
a3a9be1 [Kousuke Saruta] Modified BlockFetcherIterator for SPARK-2670
460dc01 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2670
e310c0b [Kousuke Saruta] Modified BlockFetcherIterator to handle local fetch failure as fatch fail
---
 .../spark/storage/BlockFetcherIterator.scala  |  19 ++-
 .../storage/BlockFetcherIteratorSuite.scala   | 140 ++++++++++++++++++
 2 files changed, 151 insertions(+), 8 deletions(-)
 create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala

diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index 69905a960a..ccf830e118 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -200,14 +200,17 @@ object BlockFetcherIterator {
       // these all at once because they will just memory-map some files, so they won't consume
       // any memory that might exceed our maxBytesInFlight
       for (id <- localBlocksToFetch) {
-        getLocalFromDisk(id, serializer) match {
-          case Some(iter) => {
-            // Pass 0 as size since it's not in flight
-            results.put(new FetchResult(id, 0, () => iter))
-            logDebug("Got local block " + id)
-          }
-          case None => {
-            throw new BlockException(id, "Could not get block " + id + " from local machine")
+        try {
+          // getLocalFromDisk never returns None but throws BlockException
+          val iter = getLocalFromDisk(id, serializer).get
+          // Pass 0 as size since it's not in flight
+          results.put(new FetchResult(id, 0, () => iter))
+          logDebug("Got local block " + id)
+        } catch {
+          case e: Exception => {
+            logError(s"Error occurred while fetching local blocks", e)
+            results.put(new FetchResult(id, -1, null))
+            return
           }
         }
       }
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
new file mode 100644
index 0000000000..8dca2ebb31
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.scalatest.{FunSuite, Matchers}
+import org.scalatest.PrivateMethodTester._
+
+import org.mockito.Mockito._
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.stubbing.Answer
+import org.mockito.invocation.InvocationOnMock
+
+import org.apache.spark._
+import org.apache.spark.storage.BlockFetcherIterator._
+import org.apache.spark.network.{ConnectionManager, ConnectionManagerId,
+                                 Message}
+
+class BlockFetcherIteratorSuite extends FunSuite with Matchers {
+
+  test("block fetch from local fails using BasicBlockFetcherIterator") {
+    val blockManager = mock(classOf[BlockManager])
+    val connManager = mock(classOf[ConnectionManager])
+    doReturn(connManager).when(blockManager).connectionManager
+    doReturn(BlockManagerId("test-client", "test-client", 1, 0)).when(blockManager).blockManagerId
+
+    doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight
+
+    val blIds = Array[BlockId](
+      ShuffleBlockId(0,0,0),
+      ShuffleBlockId(0,1,0),
+      ShuffleBlockId(0,2,0),
+      ShuffleBlockId(0,3,0),
+      ShuffleBlockId(0,4,0))
+
+    val optItr = mock(classOf[Option[Iterator[Any]]])
+    val answer = new Answer[Option[Iterator[Any]]] {
+      override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] {
+        throw new Exception
+      }
+    }
+
+    // 3rd block is going to fail
+    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any())
+    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any())
+    doAnswer(answer).when(blockManager).getLocalFromDisk(meq(blIds(2)), any())
+    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any())
+    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any())
+
+    val bmId = BlockManagerId("test-client", "test-client",1 , 0)
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
+    )
+
+    val iterator = new BasicBlockFetcherIterator(blockManager,
+      blocksByAddress, null)
+
+    iterator.initialize()
+
+    // 3rd getLocalFromDisk invocation should fail
+    verify(blockManager, times(3)).getLocalFromDisk(any(), any())
+
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
+    // the 2nd element of the tuple returned by iterator.next should be defined when fetching successfully
+    assert(iterator.next._2.isDefined, "1st element should be defined but is not actually defined") 
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
+    assert(iterator.next._2.isDefined, "2nd element should be defined but is not actually defined") 
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
+    // 3rd fetch should be failed
+    assert(!iterator.next._2.isDefined, "3rd element should not be defined but is actually defined") 
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements")
+    // Don't call next() after fetching a non-defined element even if there are remaining elements in the iterator.
+    // Otherwise, BasicBlockFetcherIterator hangs up.
+  }
+
+
+  test("block fetch from local succeed using BasicBlockFetcherIterator") {
+    val blockManager = mock(classOf[BlockManager])
+    val connManager = mock(classOf[ConnectionManager])
+    doReturn(connManager).when(blockManager).connectionManager
+    doReturn(BlockManagerId("test-client", "test-client", 1, 0)).when(blockManager).blockManagerId
+
+    doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight
+
+    val blIds = Array[BlockId](
+      ShuffleBlockId(0,0,0),
+      ShuffleBlockId(0,1,0),
+      ShuffleBlockId(0,2,0),
+      ShuffleBlockId(0,3,0),
+      ShuffleBlockId(0,4,0))
+
+    val optItr = mock(classOf[Option[Iterator[Any]]])
+
+    // All blocks should be fetched successfully
+    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any())
+    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any())
+    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(2)), any())
+    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any())
+    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any())
+
+    val bmId = BlockManagerId("test-client", "test-client",1 , 0)
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
+    )
+
+    val iterator = new BasicBlockFetcherIterator(blockManager,
+      blocksByAddress, null)
+
+    iterator.initialize()
+
+    // getLocalFromDisk should be invoked for all of the 5 blocks
+    verify(blockManager, times(5)).getLocalFromDisk(any(), any())
+
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
+    assert(iterator.next._2.isDefined, "All elements should be defined but 1st element is not actually defined") 
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
+    assert(iterator.next._2.isDefined, "All elements should be defined but 2nd element is not actually defined") 
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
+    assert(iterator.next._2.isDefined, "All elements should be defined but 3rd element is not actually defined") 
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements")
+    assert(iterator.next._2.isDefined, "All elements should be defined but 4th element is not actually defined") 
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements")
+    assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined") 
+  }
+
+}
-- 
GitLab