diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 1d31fce4c697b64f72cd9308e8f3a8f264b712ea..730f9806e518e79521038bcfd7dc54f722c2fe30 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -282,7 +282,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
     ) {
       jobData.numActiveStages -= 1
       if (stage.failureReason.isEmpty) {
-        jobData.completedStageIndices.add(stage.stageId)
+        if (!stage.submissionTime.isEmpty) {
+          jobData.completedStageIndices.add(stage.stageId)
+        }
       } else {
         jobData.numFailedStages += 1
       }
@@ -315,6 +317,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
       jobData <- jobIdToData.get(jobId)
     ) {
       jobData.numActiveStages += 1
+
+      // If a stage retries again, it should be removed from completedStageIndices set
+      jobData.completedStageIndices.remove(stage.stageId)
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index 3d96113aa5fe92be638217283ee4a1e48621b396..f008d401806113a7cc5cadf626dbdc3c35494fe5 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -22,6 +22,7 @@ import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
 import org.apache.spark.util.collection.OpenHashSet
 
+import scala.collection.mutable
 import scala.collection.mutable.HashMap
 
 private[spark] object UIData {
@@ -63,7 +64,7 @@ private[spark] object UIData {
     /* Stages */
     var numActiveStages: Int = 0,
     // This needs to be a set instead of a simple count to prevent double-counting of rerun stages:
-    var completedStageIndices: OpenHashSet[Int] = new OpenHashSet[Int](),
+    var completedStageIndices: mutable.HashSet[Int] = new mutable.HashSet[Int](),
     var numSkippedStages: Int = 0,
     var numFailedStages: Int = 0
   )