From 57446eb69ceb6b8856ab22b54abb22b47b80f841 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Tue, 3 Nov 2015 07:06:00 -0800
Subject: [PATCH] [SPARK-11256] Mark all Stage/ResultStage/ShuffleMapStage
 internal state as private.

Author: Reynold Xin <rxin@databricks.com>

Closes #9219 from rxin/stage-cleanup1.
---
 .../apache/spark/scheduler/DAGScheduler.scala | 33 +++++-----
 .../apache/spark/scheduler/ResultStage.scala  | 19 +++++-
 .../spark/scheduler/ShuffleMapStage.scala     | 61 +++++++++++++------
 .../org/apache/spark/scheduler/Stage.scala    |  5 +-
 4 files changed, 80 insertions(+), 38 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 995862ece5..5673fbf2c8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.Map
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack}
+import scala.collection.mutable.{HashMap, HashSet, Stack}
 import scala.concurrent.duration._
 import scala.language.existentials
 import scala.language.postfixOps
@@ -535,10 +535,8 @@ class DAGScheduler(
     jobIdToActiveJob -= job.jobId
     activeJobs -= job
     job.finalStage match {
-      case r: ResultStage =>
-        r.resultOfJob = None
-      case m: ShuffleMapStage =>
-        m.mapStageJobs = m.mapStageJobs.filter(_ != job)
+      case r: ResultStage => r.removeActiveJob()
+      case m: ShuffleMapStage => m.removeActiveJob(job)
     }
   }
 
@@ -848,7 +846,7 @@ class DAGScheduler(
     val jobSubmissionTime = clock.getTimeMillis()
     jobIdToActiveJob(jobId) = job
     activeJobs += job
-    finalStage.resultOfJob = Some(job)
+    finalStage.setActiveJob(job)
     val stageIds = jobIdToStageIds(jobId).toArray
     val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
     listenerBus.post(
@@ -880,7 +878,7 @@ class DAGScheduler(
     val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
     clearCacheLocs()
     logInfo("Got map stage job %s (%s) with %d output partitions".format(
-      jobId, callSite.shortForm, dependency.rdd.partitions.size))
+      jobId, callSite.shortForm, dependency.rdd.partitions.length))
     logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")")
     logInfo("Parents of final stage: " + finalStage.parents)
     logInfo("Missing parents: " + getMissingParentStages(finalStage))
@@ -888,7 +886,7 @@ class DAGScheduler(
     val jobSubmissionTime = clock.getTimeMillis()
     jobIdToActiveJob(jobId) = job
     activeJobs += job
-    finalStage.mapStageJobs = job :: finalStage.mapStageJobs
+    finalStage.addActiveJob(job)
     val stageIds = jobIdToStageIds(jobId).toArray
     val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
     listenerBus.post(
@@ -950,12 +948,12 @@ class DAGScheduler(
     // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
     // event.
     outputCommitCoordinator.stageStart(stage.id)
-    val taskIdToLocations = try {
+    val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
       stage match {
         case s: ShuffleMapStage =>
           partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
         case s: ResultStage =>
-          val job = s.resultOfJob.get
+          val job = s.activeJob.get
           partitionsToCompute.map { id =>
             val p = s.partitions(id)
             (id, getPreferredLocs(stage.rdd, p))
@@ -1016,7 +1014,7 @@ class DAGScheduler(
           }
 
         case stage: ResultStage =>
-          val job = stage.resultOfJob.get
+          val job = stage.activeJob.get
           partitionsToCompute.map { id =>
             val p: Int = stage.partitions(id)
             val part = stage.rdd.partitions(p)
@@ -1132,7 +1130,7 @@ class DAGScheduler(
             // Cast to ResultStage here because it's part of the ResultTask
             // TODO Refactor this out to a function that accepts a ResultStage
             val resultStage = stage.asInstanceOf[ResultStage]
-            resultStage.resultOfJob match {
+            resultStage.activeJob match {
               case Some(job) =>
                 if (!job.finished(rt.outputId)) {
                   updateAccumulators(event)
@@ -1187,7 +1185,7 @@ class DAGScheduler(
               //       we registered these map outputs.
               mapOutputTracker.registerMapOutputs(
                 shuffleStage.shuffleDep.shuffleId,
-                shuffleStage.outputLocs.map(_.headOption.orNull),
+                shuffleStage.outputLocInMapOutputTrackerFormat(),
                 changeEpoch = true)
 
               clearCacheLocs()
@@ -1197,8 +1195,7 @@ class DAGScheduler(
                 // TODO: Lower-level scheduler should also deal with this
                 logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
                   ") because some of its tasks had failed: " +
-                  shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty)
-                    .map(_._2).mkString(", "))
+                  shuffleStage.findMissingPartitions().mkString(", "))
                 submitStage(shuffleStage)
               } else {
                 // Mark any map-stage jobs waiting on this stage as finished
@@ -1312,8 +1309,10 @@ class DAGScheduler(
         // TODO: This will be really slow if we keep accumulating shuffle map stages
         for ((shuffleId, stage) <- shuffleToMapStage) {
           stage.removeOutputsOnExecutor(execId)
-          val locs = stage.outputLocs.map(_.headOption.orNull)
-          mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
+          mapOutputTracker.registerMapOutputs(
+            shuffleId,
+            stage.outputLocInMapOutputTrackerFormat(),
+            changeEpoch = true)
         }
         if (shuffleToMapStage.isEmpty) {
           mapOutputTracker.incrementEpoch()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
index c1d86af7e8..d1687830ff 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
@@ -41,10 +41,25 @@ private[spark] class ResultStage(
    * The active job for this result stage. Will be empty if the job has already finished
    * (e.g., because the job was cancelled).
    */
-  var resultOfJob: Option[ActiveJob] = None
+  private[this] var _activeJob: Option[ActiveJob] = None
 
+  def activeJob: Option[ActiveJob] = _activeJob
+
+  def setActiveJob(job: ActiveJob): Unit = {
+    _activeJob = Option(job)
+  }
+
+  def removeActiveJob(): Unit = {
+    _activeJob = None
+  }
+
+  /**
+   * Returns the sequence of partition ids that are missing (i.e. needs to be computed).
+   *
+   * This can only be called when there is an active job.
+   */
   override def findMissingPartitions(): Seq[Int] = {
-    val job = resultOfJob.get
+    val job = activeJob.get
     (0 until job.numPartitions).filter(id => !job.finished(id))
   }
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
index 3832d99edd..51416e5ce9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
@@ -43,35 +43,53 @@ private[spark] class ShuffleMapStage(
     val shuffleDep: ShuffleDependency[_, _, _])
   extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) {
 
+  private[this] var _mapStageJobs: List[ActiveJob] = Nil
+
+  private[this] var _numAvailableOutputs: Int = 0
+
+  /**
+   * List of [[MapStatus]] for each partition. The index of the array is the map partition id,
+   * and each value in the array is the list of possible [[MapStatus]] for a partition
+   * (a single task might run multiple times).
+   */
+  private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
+
   override def toString: String = "ShuffleMapStage " + id
 
-  /** Running map-stage jobs that were submitted to execute this stage independently (if any) */
-  var mapStageJobs: List[ActiveJob] = Nil
+  /**
+   * Returns the list of active jobs,
+   * i.e. map-stage jobs that were submitted to execute this stage independently (if any).
+   */
+  def mapStageJobs: Seq[ActiveJob] = _mapStageJobs
+
+  /** Adds the job to the active job list. */
+  def addActiveJob(job: ActiveJob): Unit = {
+    _mapStageJobs = job :: _mapStageJobs
+  }
+
+  /** Removes the job from the active job list. */
+  def removeActiveJob(job: ActiveJob): Unit = {
+    _mapStageJobs = _mapStageJobs.filter(_ != job)
+  }
 
   /**
    * Number of partitions that have shuffle outputs.
    * When this reaches [[numPartitions]], this map stage is ready.
    * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`.
    */
-  var numAvailableOutputs: Int = 0
+  def numAvailableOutputs: Int = _numAvailableOutputs
 
   /**
    * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs.
    * This should be the same as `outputLocs.contains(Nil)`.
    */
-  def isAvailable: Boolean = numAvailableOutputs == numPartitions
-
-  /**
-   * List of [[MapStatus]] for each partition. The index of the array is the map partition id,
-   * and each value in the array is the list of possible [[MapStatus]] for a partition
-   * (a single task might run multiple times).
-   */
-  val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
+  def isAvailable: Boolean = _numAvailableOutputs == numPartitions
 
+  /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */
   override def findMissingPartitions(): Seq[Int] = {
     val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
-    assert(missing.size == numPartitions - numAvailableOutputs,
-      s"${missing.size} missing, expected ${numPartitions - numAvailableOutputs}")
+    assert(missing.size == numPartitions - _numAvailableOutputs,
+      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
     missing
   }
 
@@ -79,7 +97,7 @@ private[spark] class ShuffleMapStage(
     val prevList = outputLocs(partition)
     outputLocs(partition) = status :: prevList
     if (prevList == Nil) {
-      numAvailableOutputs += 1
+      _numAvailableOutputs += 1
     }
   }
 
@@ -88,10 +106,19 @@ private[spark] class ShuffleMapStage(
     val newList = prevList.filterNot(_.location == bmAddress)
     outputLocs(partition) = newList
     if (prevList != Nil && newList == Nil) {
-      numAvailableOutputs -= 1
+      _numAvailableOutputs -= 1
     }
   }
 
+  /**
+   * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned
+   * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition,
+   * that position is filled with null.
+   */
+  def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = {
+    outputLocs.map(_.headOption.orNull)
+  }
+
   /**
    * Removes all shuffle outputs associated with this executor. Note that this will also remove
    * outputs which are served by an external shuffle server (if one exists), as they are still
@@ -105,12 +132,12 @@ private[spark] class ShuffleMapStage(
       outputLocs(partition) = newList
       if (prevList != Nil && newList == Nil) {
         becameUnavailable = true
-        numAvailableOutputs -= 1
+        _numAvailableOutputs -= 1
       }
     }
     if (becameUnavailable) {
       logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
-        this, execId, numAvailableOutputs, numPartitions, isAvailable))
+        this, execId, _numAvailableOutputs, numPartitions, isAvailable))
     }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 5ce4a48434..7ea24a217b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -71,8 +71,8 @@ private[scheduler] abstract class Stage(
   /** The ID to use for the next new attempt for this stage. */
   private var nextAttemptId: Int = 0
 
-  val name = callSite.shortForm
-  val details = callSite.longForm
+  val name: String = callSite.shortForm
+  val details: String = callSite.longForm
 
   private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty
 
@@ -134,6 +134,7 @@ private[scheduler] abstract class Stage(
   def latestInfo: StageInfo = _latestInfo
 
   override final def hashCode(): Int = id
+
   override final def equals(other: Any): Boolean = other match {
     case stage: Stage => stage != null && stage.id == id
     case _ => false
-- 
GitLab