diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 9eb8d4850169677492bd37d00cc1ab1057822ae3..8b33319d02df111a3da07654a6b4f8ad124a6b89 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -43,7 +43,7 @@ private[spark] class Pool( var runningTasks = 0 var priority = 0 - var stageId = 0 + var stageId = -1 var name = poolName var parent: Pool = null diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index c4f555bfe142b710cda29fc85307410414cf937a..a4e86538f99df8de7686e7e05ddfbc8f69aa6ff9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -36,8 +36,23 @@ private[spark] trait SchedulableBuilder { def addTaskSetManager(manager: Schedulable, properties: Properties) - def getTaskSetManagers(stageId: Int): Iterable[Schedulable] = { - rootPool.schedulableQueue.filter(_.stageId == stageId) + /** + * Find the TaskSetManager for the given stage. In fair scheduler, this function examines + * all the pools to find the TaskSetManager. 
+ */ + def getTaskSetManagers(stageId: Int): Option[TaskSetManager] = { + def getTsm(pool: Pool): Option[TaskSetManager] = { + pool.schedulableQueue.foreach { + case tsm: TaskSetManager => + if (tsm.stageId == stageId) { + return Some(tsm) + } + case pool: Pool => + getTsm(pool).foreach(found => return Some(found)) + } + return None + } + getTsm(rootPool) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index 031d0b1ef7d602e07dea86e5521e5fe6ccd8f0b9..250dec5126eb2df2436b8a8425eabb1c02332279 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -172,7 +172,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) override def cancelTasks(stageId: Int): Unit = synchronized { logInfo("Cancelling stage " + stageId) - schedulableBuilder.getTaskSetManagers(stageId).foreach { case tsm: TaskSetManager => + schedulableBuilder.getTaskSetManagers(stageId).foreach { tsm => // There are two possible cases here: // 1. The task set manager has been created and some tasks have been scheduled. // In this case, send a kill signal to the executors to kill the task.