diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 0b6fe70bd2062c2af9bbfa909d1b2c4c1e4e81d1..937d95a934b590ded3d5eeb122a33897a0dbcd93 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -203,6 +203,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
     for (stageId <- jobData.stageIds) {
       stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage =>
         jobsUsingStage.remove(jobEnd.jobId)
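+        // If no active jobs reference this stage anymore, drop the entry so the map does not leak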
+        if (jobsUsingStage.isEmpty) {
+          stageIdToActiveJobIds.remove(stageId)
+        }
         stageIdToInfo.get(stageId).foreach { stageInfo =>
           if (stageInfo.submissionTime.isEmpty) {
             // if this stage is pending, it won't complete, so mark it as "skipped":
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 6019282d2fb70fd7e133e463dbea50356a7013d7..730a4b54f5aa1104fe295750f3517ba9011fcbdd 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -88,6 +88,30 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
     listener.completedStages.map(_.stageId).toSet should be (Set(50, 49, 48, 47, 46))
   }
 
+  test("test clearing of stageIdToActiveJobs") {
+    val conf = new SparkConf()
+    conf.set("spark.ui.retainedStages", 5.toString)
+    val listener = new JobProgressListener(conf)
+    val jobId = 0
+    val stageIds = 1 to 50
+    // Start a job with 50 stages
+    listener.onJobStart(createJobStartEvent(jobId, stageIds))
+    for (stageId <- stageIds) {
+      listener.onStageSubmitted(createStageStartEvent(stageId))
+    }
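+    // While the job is still active, its submitted stages should be tracked in stageIdToActiveJobIds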
+    listener.stageIdToActiveJobIds.size should be > 0
+
+    // Complete the stages and job
+    for (stageId <- stageIds) {
+      listener.onStageCompleted(createStageEndEvent(stageId, failed = false))
+    }
+    listener.onJobEnd(createJobEndEvent(jobId, failed = false))
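+    // Once the job ends, no stage entries should remain in the active-job bookkeeping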
+    assertActiveJobsStateIsEmpty(listener)
+    listener.stageIdToActiveJobIds.size should be (0)
+  }
+
   test("test LRU eviction of jobs") {
     val conf = new SparkConf()
     conf.set("spark.ui.retainedStages", 5.toString)