diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 9eb8d4850169677492bd37d00cc1ab1057822ae3..8b33319d02df111a3da07654a6b4f8ad124a6b89 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -43,7 +43,7 @@ private[spark] class Pool(
   var runningTasks = 0
 
   var priority = 0
-  var stageId = 0
+  var stageId = -1
   var name = poolName
   var parent: Pool = null
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index c4f555bfe142b710cda29fc85307410414cf937a..a4e86538f99df8de7686e7e05ddfbc8f69aa6ff9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -36,8 +36,30 @@ private[spark] trait SchedulableBuilder {
 
   def addTaskSetManager(manager: Schedulable, properties: Properties)
 
-  def getTaskSetManagers(stageId: Int): Iterable[Schedulable] = {
-    rootPool.schedulableQueue.filter(_.stageId == stageId)
+  /**
+   * Find the TaskSetManager for the given stage. In fair scheduler, this function examines
+   * all the pools to find the TaskSetManager.
+   */
+  def getTaskSetManagers(stageId: Int): Option[TaskSetManager] = {
+    // Depth-first search of the pool tree rooted at `pool`, returning the first
+    // TaskSetManager whose stageId matches.
+    def getTsm(pool: Pool): Option[TaskSetManager] = {
+      pool.schedulableQueue.foreach {
+        case tsm: TaskSetManager =>
+          if (tsm.stageId == stageId) {
+            return Some(tsm)
+          }
+        case subPool: Pool =>
+          // Propagate a match found in a nested pool. Without this check the result
+          // of the recursive call is silently discarded, so stages submitted to
+          // sub-pools (the normal case under the fair scheduler) are never found.
+          val found = getTsm(subPool)
+          if (found.isDefined) {
+            return found
+          }
+      }
+      None
+    }
+    getTsm(rootPool)
   }
 }
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 031d0b1ef7d602e07dea86e5521e5fe6ccd8f0b9..250dec5126eb2df2436b8a8425eabb1c02332279 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -172,7 +172,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
 
   override def cancelTasks(stageId: Int): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    schedulableBuilder.getTaskSetManagers(stageId).foreach { case tsm: TaskSetManager =>
+    schedulableBuilder.getTaskSetManagers(stageId).foreach { tsm =>
       // There are two possible cases here:
       // 1. The task set manager has been created and some tasks have been scheduled.
       //    In this case, send a kill signal to the executors to kill the task.