From 5a93e3c58c69574eaac6458f8515579b5bd03fd9 Mon Sep 17 00:00:00 2001
From: Karen Feng <karenfeng.us@gmail.com>
Date: Sat, 27 Jul 2013 15:55:26 -0700
Subject: [PATCH] Cleaned up code based on pwendell's suggestions

---
 .../scala/spark/ui/exec/ExecutorsUI.scala     | 12 ++----
 .../main/scala/spark/ui/jobs/IndexPage.scala  | 18 ++++----
 .../scala/spark/ui/jobs/JobProgressUI.scala   | 43 ++++---------------
 3 files changed, 20 insertions(+), 53 deletions(-)

diff --git a/core/src/main/scala/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/spark/ui/exec/ExecutorsUI.scala
index 80d00c6873..948b3017db 100644
--- a/core/src/main/scala/spark/ui/exec/ExecutorsUI.scala
+++ b/core/src/main/scala/spark/ui/exec/ExecutorsUI.scala
@@ -121,10 +121,8 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
 
     override def onTaskStart(taskStart: SparkListenerTaskStart) {
       val eid = taskStart.taskInfo.executorId
-      if (!executorToTasksActive.contains(eid)) {
-        executorToTasksActive(eid) = HashSet[TaskInfo]()
-      }
-      executorToTasksActive(eid) += taskStart.taskInfo
+      val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]())
+      activeTasks += taskStart.taskInfo
       val taskList = executorToTaskInfos.getOrElse(
         eid, ArrayBuffer[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
       taskList += ((taskStart.taskInfo, None, None))
@@ -133,10 +131,8 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
 
     override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
       val eid = taskEnd.taskInfo.executorId
-      if (!executorToTasksActive.contains(eid)) {
-        executorToTasksActive(eid) = HashSet[TaskInfo]()
-      }
-      executorToTasksActive(eid) -= taskEnd.taskInfo
+      val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]())
+      activeTasks -= taskEnd.taskInfo
       val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
         taskEnd.reason match {
           case e: ExceptionFailure =>
diff --git a/core/src/main/scala/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/spark/ui/jobs/IndexPage.scala
index b862c3539c..7e504c5f9f 100644
--- a/core/src/main/scala/spark/ui/jobs/IndexPage.scala
+++ b/core/src/main/scala/spark/ui/jobs/IndexPage.scala
@@ -125,16 +125,14 @@ private[spark] class IndexPage(parent: JobProgressUI) {
       case None => "Unknown"
     }
 
-    val shuffleRead =
-      if (!listener.hasShuffleRead(s.id))
-        ""
-      else
-        Utils.memoryBytesToString(listener.stageToShuffleRead(s.id))
-    val shuffleWrite =
-      if (!listener.hasShuffleWrite(s.id))
-        ""
-      else
-        Utils.memoryBytesToString(listener.stageToShuffleWrite(s.id))
+    val shuffleRead = listener.stageToShuffleRead(s.id) match {
+      case 0 => ""
+      case b => Utils.memoryBytesToString(b)
+    }
+    val shuffleWrite = listener.stageToShuffleWrite(s.id) match {
+      case 0 => ""
+      case b => Utils.memoryBytesToString(b)
+    }
 
     val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0)
     val totalTasks = s.numPartitions
diff --git a/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala
index e7fbff7f73..09d24b6302 100644
--- a/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala
+++ b/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala
@@ -65,6 +65,7 @@ private[spark] class JobProgressListener extends SparkListener {
   val completedStages = ListBuffer[Stage]()
   val failedStages = ListBuffer[Stage]()
 
+  // Total metrics reflect metrics only for completed tasks
   var totalTime = 0L
   var totalShuffleRead = 0L
   var totalShuffleWrite = 0L
@@ -109,10 +110,8 @@ private[spark] class JobProgressListener extends SparkListener {
 
   override def onTaskStart(taskStart: SparkListenerTaskStart) {
     val sid = taskStart.task.stageId
-    if (!stageToTasksActive.contains(sid)) {
-      stageToTasksActive(sid) = HashSet[TaskInfo]()
-    }
-    stageToTasksActive(sid) += taskStart.taskInfo
+    val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+    tasksActive += taskStart.taskInfo
     val taskList = stageToTaskInfos.getOrElse(
       sid, ArrayBuffer[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
     taskList += ((taskStart.taskInfo, None, None))
@@ -121,10 +120,8 @@ private[spark] class JobProgressListener extends SparkListener {
 
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
     val sid = taskEnd.task.stageId
-    if (!stageToTasksActive.contains(sid)) {
-      stageToTasksActive(sid) = HashSet[TaskInfo]()
-    }
-    stageToTasksActive(sid) -= taskEnd.taskInfo
+    val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+    tasksActive -= taskEnd.taskInfo
     val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
       taskEnd.reason match {
         case e: ExceptionFailure =>
@@ -135,24 +132,18 @@ private[spark] class JobProgressListener extends SparkListener {
           (None, Option(taskEnd.taskMetrics))
       }
 
-    if (!stageToTime.contains(sid)) {
-      stageToTime(sid) = 0L
-    }
+    stageToTime.getOrElseUpdate(sid, 0L)
     val time = metrics.map(m => m.executorRunTime).getOrElse(0)
     stageToTime(sid) += time
     totalTime += time
 
-    if (!stageToShuffleRead.contains(sid)) {
-      stageToShuffleRead(sid) = 0L
-    }
+    stageToShuffleRead.getOrElseUpdate(sid, 0L)
     val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s =>
       s.remoteBytesRead).getOrElse(0L)
     stageToShuffleRead(sid) += shuffleRead
     totalShuffleRead += shuffleRead
 
-    if (!stageToShuffleWrite.contains(sid)) {
-      stageToShuffleWrite(sid) = 0L
-    }
+    stageToShuffleWrite.getOrElseUpdate(sid, 0L)
     val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s =>
       s.shuffleBytesWritten).getOrElse(0L)
     stageToShuffleWrite(sid) += shuffleWrite
@@ -178,22 +169,4 @@ private[spark] class JobProgressListener extends SparkListener {
       case _ =>
     }
   }
-
-  /** Is this stage's input from a shuffle read. */
-  def hasShuffleRead(stageID: Int): Boolean = {
-    // This is written in a slightly complicated way to avoid having to scan all tasks
-    for (s <- stageToTaskInfos.get(stageID).getOrElse(Seq())) {
-      if (s._2 != null) return s._2.flatMap(m => m.shuffleReadMetrics).isDefined
-    }
-    return false // No tasks have finished for this stage
-  }
-
-  /** Is this stage's output to a shuffle write. */
-  def hasShuffleWrite(stageID: Int): Boolean = {
-    // This is written in a slightly complicated way to avoid having to scan all tasks
-    for (s <- stageToTaskInfos.get(stageID).getOrElse(Seq())) {
-      if (s._2 != null) return s._2.flatMap(m => m.shuffleWriteMetrics).isDefined
-    }
-    return false // No tasks have finished for this stage
-  }
 }
-- 
GitLab