diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c43b4fd6d926f1d7f509d44c38d630f7043671a8..032b3d744c619cd8a736d6148bf6c347fb248dfd 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -17,15 +17,17 @@ package org.apache.spark +import scala.language.implicitConversions + import java.io._ import java.net.URI import java.util.concurrent.atomic.AtomicInteger import java.util.{Properties, UUID} import java.util.UUID.randomUUID import scala.collection.{Map, Set} +import scala.collection.JavaConversions._ import scala.collection.generic.Growable -import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.language.implicitConversions +import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -836,18 +838,22 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Return pools for fair scheduler - * TODO(xiajunluan): We should take nested pools into account + * :: DeveloperApi :: + * Return pools for fair scheduler */ - def getAllPools: ArrayBuffer[Schedulable] = { - taskScheduler.rootPool.schedulableQueue + @DeveloperApi + def getAllPools: Seq[Schedulable] = { + // TODO(xiajunluan): We should take nested pools into account + taskScheduler.rootPool.schedulableQueue.toSeq } /** + * :: DeveloperApi :: * Return the pool associated with the given name, if one exists */ + @DeveloperApi def getPoolForName(pool: String): Option[Schedulable] = { - taskScheduler.rootPool.schedulableNameToSchedulable.get(pool) + Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 187672c4e19e7935bb3053e1559c01b4cd42e13b..174b73221afc0038599270d25514f7e3336a7e5e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -17,8 +17,10 @@ package org.apache.spark.scheduler +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} + +import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap import org.apache.spark.Logging import org.apache.spark.scheduler.SchedulingMode.SchedulingMode @@ -35,18 +37,15 @@ private[spark] class Pool( extends Schedulable with Logging { - var schedulableQueue = new ArrayBuffer[Schedulable] - var schedulableNameToSchedulable = new HashMap[String, Schedulable] - + val schedulableQueue = new ConcurrentLinkedQueue[Schedulable] + val schedulableNameToSchedulable = new ConcurrentHashMap[String, Schedulable] var weight = initWeight var minShare = initMinShare var runningTasks = 0 - var priority = 0 // A pool's stage id is used to break the tie in scheduling. var stageId = -1 - var name = poolName var parent: Pool = null @@ -60,19 +59,20 @@ private[spark] class Pool( } override def addSchedulable(schedulable: Schedulable) { - schedulableQueue += schedulable - schedulableNameToSchedulable(schedulable.name) = schedulable + require(schedulable != null) + schedulableQueue.add(schedulable) + schedulableNameToSchedulable.put(schedulable.name, schedulable) schedulable.parent = this } override def removeSchedulable(schedulable: Schedulable) { - schedulableQueue -= schedulable - schedulableNameToSchedulable -= schedulable.name + schedulableQueue.remove(schedulable) + schedulableNameToSchedulable.remove(schedulable.name) } override def getSchedulableByName(schedulableName: String): Schedulable = { - if (schedulableNameToSchedulable.contains(schedulableName)) { - return schedulableNameToSchedulable(schedulableName) + if (schedulableNameToSchedulable.containsKey(schedulableName)) { + return schedulableNameToSchedulable.get(schedulableName) } for (schedulable <- schedulableQueue) { val sched = schedulable.getSchedulableByName(schedulableName) @@ -95,11 +95,12 @@ private[spark] class Pool( shouldRevive } - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + override def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] = { var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] - val sortedSchedulableQueue = schedulableQueue.sortWith(taskSetSchedulingAlgorithm.comparator) + val sortedSchedulableQueue = + schedulableQueue.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator) for (schedulable <- sortedSchedulableQueue) { - sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue() + sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue } sortedTaskSetQueue } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala index ed24eb6a549dd274a46f60d4fef84b5f2749018b..a87ef030e69c24a5dcf5ebead93b1c5a3d5490c4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.util.concurrent.ConcurrentLinkedQueue + import scala.collection.mutable.ArrayBuffer import org.apache.spark.scheduler.SchedulingMode.SchedulingMode @@ -28,7 +30,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode private[spark] trait Schedulable { var parent: Pool // child queues - def schedulableQueue: ArrayBuffer[Schedulable] + def schedulableQueue: ConcurrentLinkedQueue[Schedulable] def schedulingMode: SchedulingMode def weight: Int def minShare: Int @@ -42,5 +44,5 @@ private[spark] trait Schedulable { def getSchedulableByName(name: String): Schedulable def executorLost(executorId: String, host: String): Unit def checkSpeculatableTasks(): Boolean - def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] + def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 5a68f38bc5844f941788c8382fdd00048646ed8a..ffd1d9432682b802fa965034e37cb976beeba1ba 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -222,7 +222,7 @@ private[spark] class TaskSchedulerImpl( // Build a list of tasks to assign to each worker. val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores)) val availableCpus = shuffledOffers.map(o => o.cores).toArray - val sortedTaskSets = rootPool.getSortedTaskSetQueue() + val sortedTaskSets = rootPool.getSortedTaskSetQueue for (taskSet <- sortedTaskSets) { logDebug("parentName: %s, name: %s, runningTasks: %s".format( taskSet.parent.name, taskSet.name, taskSet.runningTasks)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index a8b605c5b212a24257bcfb6976c04370be726cb3..7532da88c60654714cfc5c590b4ef4cfb37cb486 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -117,7 +117,7 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin } def resourceOffer(rootPool: Pool): Int = { - val taskSetQueue = rootPool.getSortedTaskSetQueue() + val taskSetQueue = rootPool.getSortedTaskSetQueue /* Just for Test*/ for (manager <- taskSetQueue) { logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(