diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala
index 3b77a1e12cc457f4194a8bbdd939f26d37b25171..aa9c25cb5c8c6c88f2b01e6de8b3180a33b37e3c 100644
--- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala
@@ -41,11 +41,11 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen
     conf.getInt("spark.ui.retainedStages", SparkUI.DEFAULT_RETAINED_STAGES)
 
   /** Return the graph metadata for the given stage, or None if no such information exists. */
-  def getOperationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = {
-    val stageIds = jobIdToStageIds.get(jobId).getOrElse { Seq.empty }
-    val graphs = stageIds.flatMap { sid => stageIdToGraph.get(sid) }
+  def getOperationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = synchronized {
+    val _stageIds = jobIdToStageIds.get(jobId).getOrElse { Seq.empty }
+    val graphs = _stageIds.flatMap { sid => stageIdToGraph.get(sid) }
     // If the metadata for some stages have been removed, do not bother rendering this job
-    if (stageIds.size != graphs.size) {
+    if (_stageIds.size != graphs.size) {
       Seq.empty
     } else {
       graphs
@@ -53,16 +53,29 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen
   }
 
   /** Return the graph metadata for the given stage, or None if no such information exists. */
-  def getOperationGraphForStage(stageId: Int): Option[RDDOperationGraph] = {
+  def getOperationGraphForStage(stageId: Int): Option[RDDOperationGraph] = synchronized {
     stageIdToGraph.get(stageId)
   }
 
   /** On job start, construct a RDDOperationGraph for each stage in the job for display later. */
   override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized {
     val jobId = jobStart.jobId
+    val stageInfos = jobStart.stageInfos
+
     jobIds += jobId
     jobIdToStageIds(jobId) = jobStart.stageInfos.map(_.stageId).sorted
 
+    stageInfos.foreach { stageInfo =>
+      stageIds += stageInfo.stageId
+      stageIdToGraph(stageInfo.stageId) = RDDOperationGraph.makeOperationGraph(stageInfo)
+      // Remove state for old stages
+      if (stageIds.size >= retainedStages) {
+        val toRemove = math.max(retainedStages / 10, 1)
+        stageIds.take(toRemove).foreach { id => stageIdToGraph.remove(id) }
+        stageIds.trimStart(toRemove)
+      }
+    }
+
     // Remove state for old jobs
     if (jobIds.size >= retainedJobs) {
       val toRemove = math.max(retainedJobs / 10, 1)
@@ -71,15 +84,4 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen
     }
   }
 
-  /** Remove graph metadata for old stages */
-  override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized {
-    val stageInfo = stageSubmitted.stageInfo
-    stageIds += stageInfo.stageId
-    stageIdToGraph(stageInfo.stageId) = RDDOperationGraph.makeOperationGraph(stageInfo)
-    if (stageIds.size >= retainedStages) {
-      val toRemove = math.max(retainedStages / 10, 1)
-      stageIds.take(toRemove).foreach { id => stageIdToGraph.remove(id) }
-      stageIds.trimStart(toRemove)
-    }
-  }
 }
diff --git a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala
index 619b38ac02676f2ae259f6ca48aa4a9bfe530c5d..c659fc1e8b9a9ca4297ccf0b706efad5e7a214f5 100644
--- a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala
@@ -31,7 +31,6 @@ class RDDOperationGraphListenerSuite extends FunSuite {
     assert(numStages > 0, "I will not run a job with 0 stages for you.")
     val stageInfos = (0 until numStages).map { _ =>
       val stageInfo = new StageInfo(stageIdCounter, 0, "s", 0, Seq.empty, Seq.empty, "d")
-      listener.onStageSubmitted(new SparkListenerStageSubmitted(stageInfo))
       stageIdCounter += 1
       stageInfo
     }