From 69f750228f3ec8537a93da08e712596fa8004143 Mon Sep 17 00:00:00 2001
From: Andrew Or <andrewor14@gmail.com>
Date: Wed, 14 May 2014 00:54:33 -0700
Subject: [PATCH] [SPARK-1769] Executor loss causes NPE race condition

This PR replaces the Schedulable data structures in Pool.scala with thread-safe ones from Java. Note that Scala's `SynchronizedBuffer` trait is soon to be deprecated in 2.11 because it is ["inherently unreliable"](http://www.scala-lang.org/api/2.11.0/index.html#scala.collection.mutable.SynchronizedBuffer). We should gradually move away from `SynchronizedBuffer` in other places too.
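
As a minimal standalone sketch of the kind of race this guards against (the object and thread here are illustrative, not the actual scheduler code): iterating over a plain `ArrayBuffer` while another thread removes an element can observe a half-updated buffer and NPE, whereas `ConcurrentLinkedQueue` iterators are weakly consistent and tolerate concurrent removal.

```scala
import java.util.concurrent.ConcurrentLinkedQueue
import scala.collection.JavaConversions._

object RaceSketch {
  def main(args: Array[String]) {
    val queue = new ConcurrentLinkedQueue[String]
    queue.add("pool1")
    queue.add("pool2")

    // A concurrent removal (e.g. triggered by an executor loss) cannot
    // corrupt an in-progress traversal: ConcurrentLinkedQueue iterators
    // are weakly consistent, unlike iteration over an unsynchronized
    // ArrayBuffer, which races with the removal.
    val remover = new Thread(new Runnable {
      override def run() { queue.remove("pool2") }
    })
    remover.start()
    for (name <- queue) println(name)
    remover.join()
  }
}
```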

Note that this PR introduces an API-breaking change: `sc.getAllPools` now returns a Seq rather than an ArrayBuffer. This is because we want this method to return an immutable copy rather than a mutable reference that may confuse the user: modifying the returned copy has no effect on the original data structure.
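
From the caller's side the change looks like the following minimal sketch (assuming an active SparkContext `sc` with the fair scheduler configured):

```scala
// Before this patch, getAllPools exposed the scheduler's internal
// ArrayBuffer directly, so a caller could mutate live scheduler state:
//   sc.getAllPools.clear()  // would have emptied the real pool queue
// After this patch, the caller receives an immutable snapshot instead:
val pools = sc.getAllPools  // Seq[Schedulable], a defensive copy
println("Number of fair scheduler pools: " + pools.size)
```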

Author: Andrew Or <andrewor14@gmail.com>

Closes #762 from andrewor14/pool-npe and squashes the following commits:

383e739 [Andrew Or] JavaConverters -> JavaConversions
3f32981 [Andrew Or] Merge branch 'master' of github.com:apache/spark into pool-npe
769be19 [Andrew Or] Assorted minor changes
2189247 [Andrew Or] Merge branch 'master' of github.com:apache/spark into pool-npe
05ad9e9 [Andrew Or] Fix test - contains is not the same as containsKey
0921ea0 [Andrew Or] var -> val
07d720c [Andrew Or] Synchronize Schedulable data structures
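
The `contains` vs `containsKey` fix above (05ad9e9) is worth spelling out, since `java.util.ConcurrentHashMap` inherits a legacy `contains` method that checks *values*, not keys. A minimal standalone sketch of the pitfall:

```scala
import java.util.concurrent.ConcurrentHashMap

object ContainsPitfall {
  def main(args: Array[String]) {
    val map = new ConcurrentHashMap[String, Int]
    map.put("pool1", 1)

    // On java.util.ConcurrentHashMap, contains() is a legacy alias for
    // containsValue(), so testing for a *key* with it silently fails:
    println(map.contains("pool1"))     // false -- checks values!
    println(map.containsKey("pool1"))  // true  -- what we actually want
    println(map.contains(1))           // true  -- matches the value
  }
}
```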
---
 .../scala/org/apache/spark/SparkContext.scala | 20 +++++++-----
 .../org/apache/spark/scheduler/Pool.scala     | 31 ++++++++++---------
 .../apache/spark/scheduler/Schedulable.scala  |  6 ++--
 .../spark/scheduler/TaskSchedulerImpl.scala   |  2 +-
 .../scheduler/TaskSchedulerImplSuite.scala    |  2 +-
 5 files changed, 35 insertions(+), 26 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index c43b4fd6d9..032b3d744c 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -17,15 +17,17 @@
 
 package org.apache.spark
 
+import scala.language.implicitConversions
+
 import java.io._
 import java.net.URI
 import java.util.concurrent.atomic.AtomicInteger
 import java.util.{Properties, UUID}
 import java.util.UUID.randomUUID
 import scala.collection.{Map, Set}
+import scala.collection.JavaConversions._
 import scala.collection.generic.Growable
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.language.implicitConversions
+import scala.collection.mutable.HashMap
 import scala.reflect.{ClassTag, classTag}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
@@ -836,18 +838,22 @@ class SparkContext(config: SparkConf) extends Logging {
   }
 
   /**
-   *  Return pools for fair scheduler
-   *  TODO(xiajunluan): We should take nested pools into account
+   * :: DeveloperApi ::
+   * Return pools for fair scheduler
    */
-  def getAllPools: ArrayBuffer[Schedulable] = {
-    taskScheduler.rootPool.schedulableQueue
+  @DeveloperApi
+  def getAllPools: Seq[Schedulable] = {
+    // TODO(xiajunluan): We should take nested pools into account
+    taskScheduler.rootPool.schedulableQueue.toSeq
   }
 
   /**
+   * :: DeveloperApi ::
    * Return the pool associated with the given name, if one exists
    */
+  @DeveloperApi
   def getPoolForName(pool: String): Option[Schedulable] = {
-    taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)
+    Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 187672c4e1..174b73221a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.scheduler
 
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
+
+import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
 
 import org.apache.spark.Logging
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
@@ -35,18 +37,15 @@ private[spark] class Pool(
   extends Schedulable
   with Logging {
 
-  var schedulableQueue = new ArrayBuffer[Schedulable]
-  var schedulableNameToSchedulable = new HashMap[String, Schedulable]
-
+  val schedulableQueue = new ConcurrentLinkedQueue[Schedulable]
+  val schedulableNameToSchedulable = new ConcurrentHashMap[String, Schedulable]
   var weight = initWeight
   var minShare = initMinShare
   var runningTasks = 0
-
   var priority = 0
 
   // A pool's stage id is used to break the tie in scheduling.
   var stageId = -1
-
   var name = poolName
   var parent: Pool = null
 
@@ -60,19 +59,20 @@ private[spark] class Pool(
   }
 
   override def addSchedulable(schedulable: Schedulable) {
-    schedulableQueue += schedulable
-    schedulableNameToSchedulable(schedulable.name) = schedulable
+    require(schedulable != null)
+    schedulableQueue.add(schedulable)
+    schedulableNameToSchedulable.put(schedulable.name, schedulable)
     schedulable.parent = this
   }
 
   override def removeSchedulable(schedulable: Schedulable) {
-    schedulableQueue -= schedulable
-    schedulableNameToSchedulable -= schedulable.name
+    schedulableQueue.remove(schedulable)
+    schedulableNameToSchedulable.remove(schedulable.name)
   }
 
   override def getSchedulableByName(schedulableName: String): Schedulable = {
-    if (schedulableNameToSchedulable.contains(schedulableName)) {
-      return schedulableNameToSchedulable(schedulableName)
+    if (schedulableNameToSchedulable.containsKey(schedulableName)) {
+      return schedulableNameToSchedulable.get(schedulableName)
     }
     for (schedulable <- schedulableQueue) {
       val sched = schedulable.getSchedulableByName(schedulableName)
@@ -95,11 +95,12 @@ private[spark] class Pool(
     shouldRevive
   }
 
-  override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+  override def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] = {
     var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
-    val sortedSchedulableQueue = schedulableQueue.sortWith(taskSetSchedulingAlgorithm.comparator)
+    val sortedSchedulableQueue =
+      schedulableQueue.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator)
     for (schedulable <- sortedSchedulableQueue) {
-      sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue()
+      sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue
     }
     sortedTaskSetQueue
   }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
index ed24eb6a54..a87ef030e6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.scheduler
 
+import java.util.concurrent.ConcurrentLinkedQueue
+
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
@@ -28,7 +30,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 private[spark] trait Schedulable {
   var parent: Pool
   // child queues
-  def schedulableQueue: ArrayBuffer[Schedulable]
+  def schedulableQueue: ConcurrentLinkedQueue[Schedulable]
   def schedulingMode: SchedulingMode
   def weight: Int
   def minShare: Int
@@ -42,5 +44,5 @@ private[spark] trait Schedulable {
   def getSchedulableByName(name: String): Schedulable
   def executorLost(executorId: String, host: String): Unit
   def checkSpeculatableTasks(): Boolean
-  def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager]
+  def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager]
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 5a68f38bc5..ffd1d94326 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -222,7 +222,7 @@ private[spark] class TaskSchedulerImpl(
     // Build a list of tasks to assign to each worker.
     val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
     val availableCpus = shuffledOffers.map(o => o.cores).toArray
-    val sortedTaskSets = rootPool.getSortedTaskSetQueue()
+    val sortedTaskSets = rootPool.getSortedTaskSetQueue
     for (taskSet <- sortedTaskSets) {
       logDebug("parentName: %s, name: %s, runningTasks: %s".format(
         taskSet.parent.name, taskSet.name, taskSet.runningTasks))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index a8b605c5b2..7532da88c6 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -117,7 +117,7 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin
   }
 
   def resourceOffer(rootPool: Pool): Int = {
-    val taskSetQueue = rootPool.getSortedTaskSetQueue()
+    val taskSetQueue = rootPool.getSortedTaskSetQueue
     /* Just for Test*/
     for (manager <- taskSetQueue) {
        logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(
-- 
GitLab