From 563bfcc1ab1b1c79b1845230c8c600db85a08fe3 Mon Sep 17 00:00:00 2001
From: Andrew Or <andrew@databricks.com>
Date: Mon, 18 May 2015 10:59:35 -0700
Subject: [PATCH] [SPARK-7627] [SPARK-7472] DAG visualization: style skipped
 stages

This patch fixes two things:

**SPARK-7627.** Cached RDDs stopped lighting up on the job page. This is a simple styling fix.
**SPARK-7472.** Skipped stages are now displayed differently from normal stages.

The latter fixes a major UX issue: the job viz previously linked every stage to its stage page, even skipped stages, so the user could inadvertently click through to the page of a skipped stage, which is empty. Skipped stages are now grayed out and no longer link anywhere.
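
The fix threads a stage's skipped-ness from the listener to the renderer: on job end, any stage of that job that never completed is recorded as skipped, `UIUtils` then emits a `skipped` attribute on the stage's metadata `<div>`, and the JS renderer grays the stage out instead of wrapping it in a link. A minimal standalone sketch of the detection bookkeeping (the field names mirror the listener additions below, but this is not the actual Spark code):

```scala
import scala.collection.mutable

// Hypothetical stand-in for the listener's new bookkeeping: a stage that
// belongs to an ended job but never completed is considered "skipped".
object SkippedStageSketch {
  val jobIdToStageIds = mutable.HashMap[Int, Seq[Int]]()
  val jobIdToSkippedStageIds = mutable.HashMap[Int, Seq[Int]]()
  val completedStageIds = mutable.HashSet[Int]()

  def onJobEnd(jobId: Int): Unit = {
    jobIdToStageIds.get(jobId).foreach { stageIds =>
      // Whatever never completed by the time the job ended was skipped
      jobIdToSkippedStageIds(jobId) = stageIds.filterNot(completedStageIds.contains)
    }
  }

  def main(args: Array[String]): Unit = {
    jobIdToStageIds(0) = Seq(0, 1, 2)  // job 0 has three stages
    completedStageIds += 2             // only stage 2 actually ran
    onJobEnd(0)
    println(jobIdToSkippedStageIds(0)) // prints List(0, 1)
  }
}
```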

-------------------
<img src="https://cloud.githubusercontent.com/assets/2133137/7675241/de1a3da6-fcea-11e4-8101-88055cef78c5.png" width="300px" />
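
The "Fix potential leak" commit below refers to cleanup: with the extra maps (`stageIdToJobId`, `completedStageIds`, `jobIdToSkippedStageIds`), trimming jobs and stages independently could strand entries, so jobs and stages now share their fate, and cleaning either one cleans everything related to it. The `cleanStage`/`cleanJob` pair is mutually recursive, and it terminates because each entry is removed from its map before recursing, making any revisit a no-op. A condensed standalone sketch of that invariant (simplified two-map version, not the actual listener):

```scala
import scala.collection.mutable

// Hypothetical two-map version of the listener's fate-sharing cleanup
object FateSharingSketch {
  val stageIdToJobId = mutable.HashMap[Int, Int]()
  val jobIdToStageIds = mutable.HashMap[Int, Seq[Int]]()

  def cleanStage(stageId: Int): Unit =
    // remove() returns None on a second visit, so the recursion bottoms out
    stageIdToJobId.remove(stageId).foreach(cleanJob)

  def cleanJob(jobId: Int): Unit =
    jobIdToStageIds.remove(jobId).foreach(_.foreach(cleanStage))

  def main(args: Array[String]): Unit = {
    jobIdToStageIds(0) = Seq(0, 1, 2)
    Seq(0, 1, 2).foreach { sid => stageIdToJobId(sid) = 0 }
    cleanStage(1) // cleaning one stage drains both maps
    println(stageIdToJobId.isEmpty && jobIdToStageIds.isEmpty) // prints true
  }
}
```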

Author: Andrew Or <andrew@databricks.com>

Closes #6171 from andrewor14/dag-viz-skipped and squashes the following commits:

f261797 [Andrew Or] Merge branch 'master' of github.com:apache/spark into dag-viz-skipped
0eda358 [Andrew Or] Tweak skipped stage border color
c604150 [Andrew Or] Tweak grayscale colors
7010676 [Andrew Or] Merge branch 'master' of github.com:apache/spark into dag-viz-skipped
762b541 [Andrew Or] Use special prefix for stage clusters to avoid collisions
51c95b9 [Andrew Or] Merge branch 'master' of github.com:apache/spark into dag-viz-skipped
b928cd4 [Andrew Or] Fix potential leak + write tests for it
7c4c364 [Andrew Or] Show skipped stages differently
7cc34ce [Andrew Or] Merge branch 'master' of github.com:apache/spark into dag-viz-skipped
c121fa2 [Andrew Or] Fix cache color
---
 .../apache/spark/ui/static/spark-dag-viz.css  |  71 +++---
 .../apache/spark/ui/static/spark-dag-viz.js   |  50 ++--
 .../scala/org/apache/spark/ui/UIUtils.scala   |   6 +-
 .../spark/ui/scope/RDDOperationGraph.scala    |  10 +-
 .../ui/scope/RDDOperationGraphListener.scala  |  96 ++++++--
 .../RDDOperationGraphListenerSuite.scala      | 227 ++++++++++++++----
 6 files changed, 352 insertions(+), 108 deletions(-)

diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css
index eedefb44b9..3b4ae2ed35 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css
@@ -15,32 +15,21 @@
  * limitations under the License.
  */
 
-#dag-viz-graph svg path {
-  stroke: #444;
-  stroke-width: 1.5px;
-}
-
-#dag-viz-graph svg g.cluster rect {
-  stroke-width: 1px;
-}
-
-#dag-viz-graph svg g.node circle {
-  fill: #444;
+#dag-viz-graph a, #dag-viz-graph a:hover {
+  text-decoration: none;
 }
 
-#dag-viz-graph svg g.node rect {
-  fill: #C3EBFF;
-  stroke: #3EC0FF;
-  stroke-width: 1px;
+#dag-viz-graph .label {
+  font-weight: normal;
+  text-shadow: none;
 }
 
-#dag-viz-graph svg g.node.cached circle {
-  fill: #444;
+#dag-viz-graph svg path {
+  stroke: #444;
+  stroke-width: 1.5px;
 }
 
-#dag-viz-graph svg g.node.cached rect {
-  fill: #B3F5C5;
-  stroke: #56F578;
+#dag-viz-graph svg g.cluster rect {
   stroke-width: 1px;
 }
 
@@ -61,12 +50,23 @@
   stroke-width: 1px;
 }
 
-#dag-viz-graph svg.job g.cluster[class*="stage"] rect {
+#dag-viz-graph svg.job g.cluster.skipped rect {
+  fill: #D6D6D6;
+  stroke: #B7B7B7;
+  stroke-width: 1px;
+}
+
+#dag-viz-graph svg.job g.cluster.stage rect {
   fill: #FFFFFF;
   stroke: #FF99AC;
   stroke-width: 1px;
 }
 
+#dag-viz-graph svg.job g.cluster.stage.skipped rect {
+  stroke: #ADADAD;
+  stroke-width: 1px;
+}
+
 #dag-viz-graph svg.job g#cross-stage-edges path {
   fill: none;
 }
@@ -75,6 +75,20 @@
   fill: #333;
 }
 
+#dag-viz-graph svg.job g.cluster.skipped text {
+  fill: #666;
+}
+
+#dag-viz-graph svg.job g.node circle {
+  fill: #444;
+}
+
+#dag-viz-graph svg.job g.node.cached circle {
+  fill: #A3F545;
+  stroke: #52C366;
+  stroke-width: 2px;
+}
+
 /* Stage page specific styles */
 
 #dag-viz-graph svg.stage g.cluster rect {
@@ -83,7 +97,7 @@
   stroke-width: 1px;
 }
 
-#dag-viz-graph svg.stage g.cluster[class*="stage"] rect {
+#dag-viz-graph svg.stage g.cluster.stage rect {
   fill: #FFFFFF;
   stroke: #FFA6B6;
   stroke-width: 1px;
@@ -97,11 +111,14 @@
   fill: #333;
 }
 
-#dag-viz-graph a, #dag-viz-graph a:hover {
-  text-decoration: none;
+#dag-viz-graph svg.stage g.node rect {
+  fill: #C3EBFF;
+  stroke: #3EC0FF;
+  stroke-width: 1px;
 }
 
-#dag-viz-graph .label {
-  font-weight: normal;
-  text-shadow: none;
+#dag-viz-graph svg.stage g.node.cached rect {
+  fill: #B3F5C5;
+  stroke: #52C366;
+  stroke-width: 2px;
 }
diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
index ee48fd29a6..aaeba5b102 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
@@ -57,9 +57,7 @@ var VizConstants = {
   stageSep: 40,
   graphPrefix: "graph_",
   nodePrefix: "node_",
-  stagePrefix: "stage_",
-  clusterPrefix: "cluster_",
-  stageClusterPrefix: "cluster_stage_"
+  clusterPrefix: "cluster_"
 };
 
 var JobPageVizConstants = {
@@ -133,9 +131,7 @@ function renderDagViz(forJob) {
   }
 
   // Render
-  var svg = graphContainer()
-    .append("svg")
-    .attr("class", jobOrStage);
+  var svg = graphContainer().append("svg").attr("class", jobOrStage);
   if (forJob) {
     renderDagVizForJob(svg);
   } else {
@@ -185,23 +181,32 @@ function renderDagVizForJob(svgContainer) {
     var dot = metadata.select(".dot-file").text();
     var stageId = metadata.attr("stage-id");
     var containerId = VizConstants.graphPrefix + stageId;
-    // Link each graph to the corresponding stage page (TODO: handle stage attempts)
-    var stageLink = $("#stage-" + stageId.replace(VizConstants.stagePrefix, "") + "-0")
-      .find("a")
-      .attr("href") + "&expandDagViz=true";
-    var container = svgContainer
-      .append("a")
-      .attr("xlink:href", stageLink)
-      .append("g")
-      .attr("id", containerId);
+    var isSkipped = metadata.attr("skipped") == "true";
+    var container;
+    if (isSkipped) {
+      container = svgContainer
+        .append("g")
+        .attr("id", containerId)
+        .attr("skipped", "true");
+    } else {
+      // Link each graph to the corresponding stage page (TODO: handle stage attempts)
+      // Use the link from the stage table so it also works for the history server
+      var attemptId = 0;
+      var stageLink = d3.select("#stage-" + stageId + "-" + attemptId)
+        .select("a")
+        .attr("href") + "&expandDagViz=true";
+      container = svgContainer
+        .append("a")
+        .attr("xlink:href", stageLink)
+        .append("g")
+        .attr("id", containerId);
+    }
 
     // Now we need to shift the container for this stage so it doesn't overlap with
     // existing ones, taking into account the position and width of the last stage's
     // container. We do not need to do this for the first stage of this job.
     if (i > 0) {
-      var existingStages = svgContainer
-        .selectAll("g.cluster")
-        .filter("[class*=\"" + VizConstants.stageClusterPrefix + "\"]");
+      var existingStages = svgContainer.selectAll("g.cluster.stage");
       if (!existingStages.empty()) {
         var lastStage = d3.select(existingStages[0].pop());
         var lastStageWidth = toFloat(lastStage.select("rect").attr("width"));
@@ -214,6 +219,12 @@ function renderDagVizForJob(svgContainer) {
     // Actually render the stage
     renderDot(dot, container, true);
 
+    // Mark elements as skipped if appropriate. Unfortunately we need to mark all
+    // elements instead of the parent container because of CSS override rules.
+    if (isSkipped) {
+      container.selectAll("g").classed("skipped", true);
+    }
+
     // Round corners on rectangles
     container
       .selectAll("rect")
@@ -243,6 +254,9 @@ function renderDot(dot, container, forJob) {
   var renderer = new dagreD3.render();
   preprocessGraphLayout(g, forJob);
   renderer(container, g);
+
+  // Find the stage cluster and mark it for styling and post-processing
+  container.selectAll("g.cluster[name*=\"Stage\"]").classed("stage", true);
 }
 
 /* -------------------- *
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index ad16becde8..6194c50ec8 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -352,10 +352,12 @@ private[spark] object UIUtils extends Logging {
         </a>
       </span>
       <div id="dag-viz-graph"></div>
-      <div id="dag-viz-metadata">
+      <div id="dag-viz-metadata" style="display:none">
         {
           graphs.map { g =>
-            <div class="stage-metadata" stage-id={g.rootCluster.id} style="display:none">
+            val stageId = g.rootCluster.id.replaceAll(RDDOperationGraph.STAGE_CLUSTER_PREFIX, "")
+            val skipped = g.rootCluster.name.contains("skipped").toString
+            <div class="stage-metadata" stage-id={stageId} skipped={skipped}>
               <div class="dot-file">{RDDOperationGraph.makeDotFile(g)}</div>
               { g.incomingEdges.map { e => <div class="incoming-edge">{e.fromId},{e.toId}</div> } }
               { g.outgoingEdges.map { e => <div class="outgoing-edge">{e.fromId},{e.toId}</div> } }
diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
index 25d5c6ff7e..33a7303be7 100644
--- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
+++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
@@ -52,10 +52,13 @@ private[ui] case class RDDOperationEdge(fromId: Int, toId: Int)
  * This represents any grouping of RDDs, including operation scopes (e.g. textFile, flatMap),
  * stages, jobs, or any higher level construct. A cluster may be nested inside of other clusters.
  */
-private[ui] class RDDOperationCluster(val id: String, val name: String) {
+private[ui] class RDDOperationCluster(val id: String, private var _name: String) {
   private val _childNodes = new ListBuffer[RDDOperationNode]
   private val _childClusters = new ListBuffer[RDDOperationCluster]
 
+  def name: String = _name
+  def setName(n: String): Unit = { _name = n }
+
   def childNodes: Seq[RDDOperationNode] = _childNodes.iterator.toSeq
   def childClusters: Seq[RDDOperationCluster] = _childClusters.iterator.toSeq
   def attachChildNode(childNode: RDDOperationNode): Unit = { _childNodes += childNode }
@@ -71,6 +74,8 @@ private[ui] class RDDOperationCluster(val id: String, val name: String) {
 
 private[ui] object RDDOperationGraph extends Logging {
 
+  val STAGE_CLUSTER_PREFIX = "stage_"
+
   /**
    * Construct a RDDOperationGraph for a given stage.
    *
@@ -88,7 +93,8 @@ private[ui] object RDDOperationGraph extends Logging {
     val clusters = new mutable.HashMap[String, RDDOperationCluster] // indexed by cluster ID
 
     // Root cluster is the stage cluster
-    val stageClusterId = s"stage_${stage.stageId}"
+    // Use a special prefix here to differentiate this cluster from other operation clusters
+    val stageClusterId = STAGE_CLUSTER_PREFIX + stage.stageId
     val stageClusterName = s"Stage ${stage.stageId}" +
       { if (stage.attemptId == 0) "" else s" (attempt ${stage.attemptId})" }
     val rootCluster = new RDDOperationCluster(stageClusterId, stageClusterName)
diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala
index aa9c25cb5c..89119cd357 100644
--- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala
@@ -27,8 +27,15 @@ import org.apache.spark.ui.SparkUI
  * A SparkListener that constructs a DAG of RDD operations.
  */
 private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListener {
+
+  // Note: the fates of jobs and stages are tied. This means when we clean up a job,
+  // we always clean up all of its stages. Similarly, when we clean up a stage, we
+  // always clean up its job (and, transitively, other stages in the same job).
   private[ui] val jobIdToStageIds = new mutable.HashMap[Int, Seq[Int]]
+  private[ui] val jobIdToSkippedStageIds = new mutable.HashMap[Int, Seq[Int]]
+  private[ui] val stageIdToJobId = new mutable.HashMap[Int, Int]
   private[ui] val stageIdToGraph = new mutable.HashMap[Int, RDDOperationGraph]
+  private[ui] val completedStageIds = new mutable.HashSet[Int]
 
   // Keep track of the order in which these are inserted so we can remove old ones
   private[ui] val jobIds = new mutable.ArrayBuffer[Int]
@@ -40,16 +47,23 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen
   private val retainedStages =
     conf.getInt("spark.ui.retainedStages", SparkUI.DEFAULT_RETAINED_STAGES)
 
-  /** Return the graph metadata for the given stage, or None if no such information exists. */
+  /**
+   * Return the graph metadata for all stages in the given job.
+   * An empty list is returned if one or more of its stages has been cleaned up.
+   */
   def getOperationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = synchronized {
-    val _stageIds = jobIdToStageIds.get(jobId).getOrElse { Seq.empty }
-    val graphs = _stageIds.flatMap { sid => stageIdToGraph.get(sid) }
-    // If the metadata for some stages have been removed, do not bother rendering this job
-    if (_stageIds.size != graphs.size) {
-      Seq.empty
-    } else {
-      graphs
+    val skippedStageIds = jobIdToSkippedStageIds.get(jobId).getOrElse(Seq.empty)
+    val graphs = jobIdToStageIds.get(jobId)
+      .getOrElse(Seq.empty)
+      .flatMap { sid => stageIdToGraph.get(sid) }
+    // Mark any skipped stages as such
+    graphs.foreach { g =>
+      val stageId = g.rootCluster.id.replaceAll(RDDOperationGraph.STAGE_CLUSTER_PREFIX, "").toInt
+      if (skippedStageIds.contains(stageId) && !g.rootCluster.name.contains("skipped")) {
+        g.rootCluster.setName(g.rootCluster.name + " (skipped)")
+      }
     }
+    graphs
   }
 
   /** Return the graph metadata for the given stage, or None if no such information exists. */
@@ -66,22 +80,68 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen
     jobIdToStageIds(jobId) = jobStart.stageInfos.map(_.stageId).sorted
 
     stageInfos.foreach { stageInfo =>
-      stageIds += stageInfo.stageId
-      stageIdToGraph(stageInfo.stageId) = RDDOperationGraph.makeOperationGraph(stageInfo)
-      // Remove state for old stages
-      if (stageIds.size >= retainedStages) {
-        val toRemove = math.max(retainedStages / 10, 1)
-        stageIds.take(toRemove).foreach { id => stageIdToGraph.remove(id) }
-        stageIds.trimStart(toRemove)
-      }
+      val stageId = stageInfo.stageId
+      stageIds += stageId
+      stageIdToJobId(stageId) = jobId
+      stageIdToGraph(stageId) = RDDOperationGraph.makeOperationGraph(stageInfo)
+      trimStagesIfNecessary()
+    }
+
+    trimJobsIfNecessary()
+  }
+
+  /** Keep track of stages that have completed. */
+  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized {
+    val stageId = stageCompleted.stageInfo.stageId
+    if (stageIdToJobId.contains(stageId)) {
+      // Note: Only do this if the stage has not already been cleaned up
+      // Otherwise, we may never clean this stage from `completedStageIds`
+      completedStageIds += stageId
+    }
+  }
+
+  /** On job end, find all stages in this job that are skipped and mark them as such. */
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized {
+    val jobId = jobEnd.jobId
+    jobIdToStageIds.get(jobId).foreach { stageIds =>
+      val skippedStageIds = stageIds.filter { sid => !completedStageIds.contains(sid) }
+      // Note: Only do this if the job has not already been cleaned up
+      // Otherwise, we may never clean this job from `jobIdToSkippedStageIds`
+      jobIdToSkippedStageIds(jobId) = skippedStageIds
     }
+  }
+
+  /** Clean metadata for old stages if we have exceeded the number to retain. */
+  private def trimStagesIfNecessary(): Unit = {
+    if (stageIds.size >= retainedStages) {
+      val toRemove = math.max(retainedStages / 10, 1)
+      stageIds.take(toRemove).foreach { id => cleanStage(id) }
+      stageIds.trimStart(toRemove)
+    }
+  }
 
-    // Remove state for old jobs
+  /** Clean metadata for old jobs if we have exceeded the number to retain. */
+  private def trimJobsIfNecessary(): Unit = {
     if (jobIds.size >= retainedJobs) {
       val toRemove = math.max(retainedJobs / 10, 1)
-      jobIds.take(toRemove).foreach { id => jobIdToStageIds.remove(id) }
+      jobIds.take(toRemove).foreach { id => cleanJob(id) }
       jobIds.trimStart(toRemove)
     }
   }
 
+  /** Clean metadata for the given stage, its job, and all other stages that belong to the job. */
+  private[ui] def cleanStage(stageId: Int): Unit = {
+    completedStageIds.remove(stageId)
+    stageIdToGraph.remove(stageId)
+    stageIdToJobId.remove(stageId).foreach { jobId => cleanJob(jobId) }
+  }
+
+  /** Clean metadata for the given job and all stages that belong to it. */
+  private[ui] def cleanJob(jobId: Int): Unit = {
+    jobIdToSkippedStageIds.remove(jobId)
+    jobIdToStageIds.remove(jobId).foreach { stageIds =>
+      stageIds.foreach { stageId => cleanStage(stageId) }
+    }
+  }
+
 }
diff --git a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala
index c659fc1e8b..c1126f3af5 100644
--- a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala
@@ -20,67 +20,212 @@ package org.apache.spark.ui.scope
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkConf
-import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListenerStageSubmitted, StageInfo}
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.SparkListenerStageSubmitted
+import org.apache.spark.scheduler.SparkListenerStageCompleted
+import org.apache.spark.scheduler.SparkListenerJobStart
 
+/**
+ * Tests that this listener populates and cleans up its data structures properly.
+ */
 class RDDOperationGraphListenerSuite extends FunSuite {
   private var jobIdCounter = 0
   private var stageIdCounter = 0
+  private val maxRetainedJobs = 10
+  private val maxRetainedStages = 10
+  private val conf = new SparkConf()
+    .set("spark.ui.retainedJobs", maxRetainedJobs.toString)
+    .set("spark.ui.retainedStages", maxRetainedStages.toString)
 
-  /** Run a job with the specified number of stages. */
-  private def runOneJob(numStages: Int, listener: RDDOperationGraphListener): Unit = {
-    assert(numStages > 0, "I will not run a job with 0 stages for you.")
-    val stageInfos = (0 until numStages).map { _ =>
-      val stageInfo = new StageInfo(stageIdCounter, 0, "s", 0, Seq.empty, Seq.empty, "d")
-      stageIdCounter += 1
-      stageInfo
-    }
-    listener.onJobStart(new SparkListenerJobStart(jobIdCounter, 0, stageInfos))
-    jobIdCounter += 1
-  }
-
-  test("listener cleans up metadata") {
-
-    val conf = new SparkConf()
-      .set("spark.ui.retainedStages", "10")
-      .set("spark.ui.retainedJobs", "10")
-
+  test("run normal jobs") {
+    val startingJobId = jobIdCounter
+    val startingStageId = stageIdCounter
     val listener = new RDDOperationGraphListener(conf)
     assert(listener.jobIdToStageIds.isEmpty)
+    assert(listener.jobIdToSkippedStageIds.isEmpty)
+    assert(listener.stageIdToJobId.isEmpty)
     assert(listener.stageIdToGraph.isEmpty)
+    assert(listener.completedStageIds.isEmpty)
     assert(listener.jobIds.isEmpty)
     assert(listener.stageIds.isEmpty)
 
     // Run a few jobs, but not enough for clean up yet
-    runOneJob(1, listener)
-    runOneJob(2, listener)
-    runOneJob(3, listener)
+    (1 to 3).foreach { numStages => startJob(numStages, listener) } // start 3 jobs and 6 stages
+    (0 to 5).foreach { i => endStage(startingStageId + i, listener) } // finish all 6 stages
+    (0 to 2).foreach { i => endJob(startingJobId + i, listener) } // finish all 3 jobs
+
     assert(listener.jobIdToStageIds.size === 3)
+    assert(listener.jobIdToStageIds(startingJobId).size === 1)
+    assert(listener.jobIdToStageIds(startingJobId + 1).size === 2)
+    assert(listener.jobIdToStageIds(startingJobId + 2).size === 3)
+    assert(listener.jobIdToSkippedStageIds.size === 3)
+    assert(listener.jobIdToSkippedStageIds.values.forall(_.isEmpty)) // no skipped stages
+    assert(listener.stageIdToJobId.size === 6)
+    assert(listener.stageIdToJobId(startingStageId) === startingJobId)
+    assert(listener.stageIdToJobId(startingStageId + 1) === startingJobId + 1)
+    assert(listener.stageIdToJobId(startingStageId + 2) === startingJobId + 1)
+    assert(listener.stageIdToJobId(startingStageId + 3) === startingJobId + 2)
+    assert(listener.stageIdToJobId(startingStageId + 4) === startingJobId + 2)
+    assert(listener.stageIdToJobId(startingStageId + 5) === startingJobId + 2)
     assert(listener.stageIdToGraph.size === 6)
+    assert(listener.completedStageIds.size === 6)
     assert(listener.jobIds.size === 3)
     assert(listener.stageIds.size === 6)
+  }
+
+  test("run jobs with skipped stages") {
+    val startingJobId = jobIdCounter
+    val startingStageId = stageIdCounter
+    val listener = new RDDOperationGraphListener(conf)
+
+    // Run a few jobs, but not enough for clean up yet
+    // Leave some stages unfinished so that they are marked as skipped
+    (1 to 3).foreach { numStages => startJob(numStages, listener) } // start 3 jobs and 6 stages
+    (4 to 5).foreach { i => endStage(startingStageId + i, listener) } // finish only last 2 stages
+    (0 to 2).foreach { i => endJob(startingJobId + i, listener) } // finish all 3 jobs
+
+    assert(listener.jobIdToSkippedStageIds.size === 3)
+    assert(listener.jobIdToSkippedStageIds(startingJobId).size === 1)
+    assert(listener.jobIdToSkippedStageIds(startingJobId + 1).size === 2)
+    assert(listener.jobIdToSkippedStageIds(startingJobId + 2).size === 1) // 2 stages not skipped
+    assert(listener.completedStageIds.size === 2)
+
+    // The rest should be the same as before
+    assert(listener.jobIdToStageIds.size === 3)
+    assert(listener.jobIdToStageIds(startingJobId).size === 1)
+    assert(listener.jobIdToStageIds(startingJobId + 1).size === 2)
+    assert(listener.jobIdToStageIds(startingJobId + 2).size === 3)
+    assert(listener.stageIdToJobId.size === 6)
+    assert(listener.stageIdToJobId(startingStageId) === startingJobId)
+    assert(listener.stageIdToJobId(startingStageId + 1) === startingJobId + 1)
+    assert(listener.stageIdToJobId(startingStageId + 2) === startingJobId + 1)
+    assert(listener.stageIdToJobId(startingStageId + 3) === startingJobId + 2)
+    assert(listener.stageIdToJobId(startingStageId + 4) === startingJobId + 2)
+    assert(listener.stageIdToJobId(startingStageId + 5) === startingJobId + 2)
+    assert(listener.stageIdToGraph.size === 6)
+    assert(listener.jobIds.size === 3)
+    assert(listener.stageIds.size === 6)
+  }
+
+  test("clean up metadata") {
+    val startingJobId = jobIdCounter
+    val startingStageId = stageIdCounter
+    val listener = new RDDOperationGraphListener(conf)
 
-    // Run a few more, but this time the stages should be cleaned up, but not the jobs
-    runOneJob(5, listener)
-    runOneJob(100, listener)
-    assert(listener.jobIdToStageIds.size === 5)
-    assert(listener.stageIdToGraph.size === 9)
-    assert(listener.jobIds.size === 5)
-    assert(listener.stageIds.size === 9)
-
-    // Run a few more, but this time both jobs and stages should be cleaned up
-    (1 to 100).foreach { _ =>
-      runOneJob(1, listener)
+    // Run many jobs and stages to trigger clean up
+    (1 to 10000).foreach { i =>
+      // Note: this must be less than `maxRetainedStages`
+      val numStages = i % (maxRetainedStages - 2) + 1
+      val startingStageIdForJob = stageIdCounter
+      val jobId = startJob(numStages, listener)
+      // End some, but not all, stages that belong to this job
+      // This is to ensure that we have both completed and skipped stages
+      (startingStageIdForJob until stageIdCounter)
+        .filter { i => i % 2 == 0 }
+        .foreach { i => endStage(i, listener) }
+      // End all jobs
+      endJob(jobId, listener)
     }
-    assert(listener.jobIdToStageIds.size === 9)
-    assert(listener.stageIdToGraph.size === 9)
-    assert(listener.jobIds.size === 9)
-    assert(listener.stageIds.size === 9)
+
+    // Ensure we never exceed the max retained thresholds
+    assert(listener.jobIdToStageIds.size <= maxRetainedJobs)
+    assert(listener.jobIdToSkippedStageIds.size <= maxRetainedJobs)
+    assert(listener.stageIdToJobId.size <= maxRetainedStages)
+    assert(listener.stageIdToGraph.size <= maxRetainedStages)
+    assert(listener.completedStageIds.size <= maxRetainedStages)
+    assert(listener.jobIds.size <= maxRetainedJobs)
+    assert(listener.stageIds.size <= maxRetainedStages)
+
+    // Also ensure we're actually populating these data structures
+    // Otherwise the previous group of asserts will be meaningless
+    assert(listener.jobIdToStageIds.nonEmpty)
+    assert(listener.jobIdToSkippedStageIds.nonEmpty)
+    assert(listener.stageIdToJobId.nonEmpty)
+    assert(listener.stageIdToGraph.nonEmpty)
+    assert(listener.completedStageIds.nonEmpty)
+    assert(listener.jobIds.nonEmpty)
+    assert(listener.stageIds.nonEmpty)
 
     // Ensure we clean up old jobs and stages, not arbitrary ones
-    assert(!listener.jobIdToStageIds.contains(0))
-    assert(!listener.stageIdToGraph.contains(0))
-    assert(!listener.stageIds.contains(0))
-    assert(!listener.jobIds.contains(0))
+    assert(!listener.jobIdToStageIds.contains(startingJobId))
+    assert(!listener.jobIdToSkippedStageIds.contains(startingJobId))
+    assert(!listener.stageIdToJobId.contains(startingStageId))
+    assert(!listener.stageIdToGraph.contains(startingStageId))
+    assert(!listener.completedStageIds.contains(startingStageId))
+    assert(!listener.stageIds.contains(startingStageId))
+    assert(!listener.jobIds.contains(startingJobId))
+  }
+
+  test("fate sharing between jobs and stages") {
+    val startingJobId = jobIdCounter
+    val startingStageId = stageIdCounter
+    val listener = new RDDOperationGraphListener(conf)
+
+    // Run 3 jobs and 8 stages, finishing all 3 jobs but only 2 stages
+    startJob(5, listener)
+    startJob(1, listener)
+    startJob(2, listener)
+    (0 until 8).foreach { i => startStage(i + startingStageId, listener) }
+    endStage(startingStageId + 3, listener)
+    endStage(startingStageId + 4, listener)
+    (0 until 3).foreach { i => endJob(i + startingJobId, listener) }
+
+    // First, assert the old stuff
+    assert(listener.jobIdToStageIds.size === 3)
+    assert(listener.jobIdToSkippedStageIds.size === 3)
+    assert(listener.stageIdToJobId.size === 8)
+    assert(listener.stageIdToGraph.size === 8)
+    assert(listener.completedStageIds.size === 2)
+
+    // Cleaning the third job should clean all of its stages
+    listener.cleanJob(startingJobId + 2)
+    assert(listener.jobIdToStageIds.size === 2)
+    assert(listener.jobIdToSkippedStageIds.size === 2)
+    assert(listener.stageIdToJobId.size === 6)
+    assert(listener.stageIdToGraph.size === 6)
+    assert(listener.completedStageIds.size === 2)
+
+    // Cleaning one of the stages in the first job should clean that job and all of its stages
+    // Note that we still keep around the last stage because it belongs to a different job
+    listener.cleanStage(startingStageId)
+    assert(listener.jobIdToStageIds.size === 1)
+    assert(listener.jobIdToSkippedStageIds.size === 1)
+    assert(listener.stageIdToJobId.size === 1)
+    assert(listener.stageIdToGraph.size === 1)
+    assert(listener.completedStageIds.size === 0)
+  }
+
+  /** Start a job with the specified number of stages. */
+  private def startJob(numStages: Int, listener: RDDOperationGraphListener): Int = {
+    assert(numStages > 0, "I will not run a job with 0 stages for you.")
+    val stageInfos = (0 until numStages).map { _ =>
+      val stageInfo = new StageInfo(stageIdCounter, 0, "s", 0, Seq.empty, Seq.empty, "d")
+      stageIdCounter += 1
+      stageInfo
+    }
+    val jobId = jobIdCounter
+    listener.onJobStart(new SparkListenerJobStart(jobId, 0, stageInfos))
+    // Also start all stages that belong to this job
+    stageInfos.map(_.stageId).foreach { sid => startStage(sid, listener) }
+    jobIdCounter += 1
+    jobId
+  }
+
+  /** Start the stage specified by the given ID. */
+  private def startStage(stageId: Int, listener: RDDOperationGraphListener): Unit = {
+    val stageInfo = new StageInfo(stageId, 0, "s", 0, Seq.empty, Seq.empty, "d")
+    listener.onStageSubmitted(new SparkListenerStageSubmitted(stageInfo))
+  }
+
+  /** Finish the stage specified by the given ID. */
+  private def endStage(stageId: Int, listener: RDDOperationGraphListener): Unit = {
+    val stageInfo = new StageInfo(stageId, 0, "s", 0, Seq.empty, Seq.empty, "d")
+    listener.onStageCompleted(new SparkListenerStageCompleted(stageInfo))
+  }
+
+  /** Finish the job specified by the given ID. */
+  private def endJob(jobId: Int, listener: RDDOperationGraphListener): Unit = {
+    listener.onJobEnd(new SparkListenerJobEnd(jobId, 0, JobSucceeded))
   }
 
 }
-- 
GitLab