From 540b2a4eabe0bad2455f5140c4ad6a315e37cc3f Mon Sep 17 00:00:00 2001
From: Wenchen Fan <cloud0fan@outlook.com>
Date: Wed, 18 Mar 2015 19:43:04 -0700
Subject: [PATCH] [SPARK-6394][Core] cleanup BlockManager companion object and
 improve the getCacheLocs method in DAGScheduler

The current implementation include searching a HashMap many times, we can avoid this.
Actually if you look into `BlockManager.blockIdsToBlockManagers`, the core function call is [this](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/storage/BlockManager.scala#L1258), so we can call `blockManagerMaster.getLocations` directly and avoid building a HashMap.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #5043 from cloud-fan/small and squashes the following commits:

e959d12 [Wenchen Fan] fix style
203c493 [Wenchen Fan] some cleanup in BlockManager companion object
d409099 [Wenchen Fan] address rxin's comment
faec999 [Wenchen Fan] add regression test
2fb57aa [Wenchen Fan] imporve the getCacheLocs method
---
 .../apache/spark/scheduler/DAGScheduler.scala | 11 +++++-----
 .../apache/spark/storage/BlockManager.scala   | 22 ++++---------------
 .../spark/scheduler/DAGSchedulerSuite.scala   | 12 ++++++++++
 3 files changed, 22 insertions(+), 23 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index e4170a55b7..1021172e6a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -104,7 +104,7 @@ class DAGScheduler(
    *
    * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454).
    */
-  private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
+  private val cacheLocs = new HashMap[Int, Seq[Seq[TaskLocation]]]
 
   // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with
   // every task. When we detect a node failing, we note the current epoch number and failed
@@ -188,14 +188,15 @@ class DAGScheduler(
     eventProcessLoop.post(TaskSetFailed(taskSet, reason))
   }
 
-  private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = cacheLocs.synchronized {
+  private[scheduler]
+  def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized {
     // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times
     if (!cacheLocs.contains(rdd.id)) {
       val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
-      val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
-      cacheLocs(rdd.id) = blockIds.map { id =>
-        locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
+      val locs: Seq[Seq[TaskLocation]] = blockManagerMaster.getLocations(blockIds).map { bms =>
+        bms.map(bm => TaskLocation(bm.host, bm.executorId))
       }
+      cacheLocs(rdd.id) = locs
     }
     cacheLocs(rdd.id)
   }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index c8b7763f03..80d66e5913 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -1245,10 +1245,10 @@ private[spark] object BlockManager extends Logging {
     }
   }
 
-  def blockIdsToBlockManagers(
+  def blockIdsToHosts(
       blockIds: Array[BlockId],
       env: SparkEnv,
-      blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[BlockManagerId]] = {
+      blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = {
 
     // blockManagerMaster != null is used in tests
     assert(env != null || blockManagerMaster != null)
@@ -1258,24 +1258,10 @@ private[spark] object BlockManager extends Logging {
       blockManagerMaster.getLocations(blockIds)
     }
 
-    val blockManagers = new HashMap[BlockId, Seq[BlockManagerId]]
+    val blockManagers = new HashMap[BlockId, Seq[String]]
     for (i <- 0 until blockIds.length) {
-      blockManagers(blockIds(i)) = blockLocations(i)
+      blockManagers(blockIds(i)) = blockLocations(i).map(_.host)
     }
     blockManagers.toMap
   }
-
-  def blockIdsToExecutorIds(
-      blockIds: Array[BlockId],
-      env: SparkEnv,
-      blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = {
-    blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId))
-  }
-
-  def blockIdsToHosts(
-      blockIds: Array[BlockId],
-      env: SparkEnv,
-      blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = {
-    blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host))
-  }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 30119ce5d4..63360a0f18 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -322,6 +322,18 @@ class DAGSchedulerSuite extends FunSuiteLike  with BeforeAndAfter with LocalSpar
     assertDataStructuresEmpty
   }
 
+  test("regression test for getCacheLocs") {
+    val rdd = new MyRDD(sc, 3, Nil)
+    cacheLocations(rdd.id -> 0) =
+      Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+    cacheLocations(rdd.id -> 1) =
+      Seq(makeBlockManagerId("hostB"), makeBlockManagerId("hostC"))
+    cacheLocations(rdd.id -> 2) =
+      Seq(makeBlockManagerId("hostC"), makeBlockManagerId("hostD"))
+    val locs = scheduler.getCacheLocs(rdd).map(_.map(_.host))
+    assert(locs === Seq(Seq("hostA", "hostB"), Seq("hostB", "hostC"), Seq("hostC", "hostD")))
+  }
+
   test("avoid exponential blowup when getting preferred locs list") {
     // Build up a complex dependency graph with repeated zip operations, without preferred locations.
     var rdd: RDD[_] = new MyRDD(sc, 1, Nil)
-- 
GitLab