From 20f01a0a1be1349990bb86426f99f4f446432f0c Mon Sep 17 00:00:00 2001
From: Imran Rashid <imran@quantifind.com>
Date: Sat, 9 Mar 2013 21:17:31 -0800
Subject: [PATCH] enable task metrics in local mode, add tests

---
 .../scheduler/local/LocalScheduler.scala      | 10 ++-
 .../spark/scheduler/SparkListenerSuite.scala  | 80 +++++++++++++++++++
 2 files changed, 88 insertions(+), 2 deletions(-)
 create mode 100644 core/src/test/scala/spark/scheduler/SparkListenerSuite.scala

diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index a76253ea14..9e1bde3fbe 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -67,8 +67,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
         logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
         updateDependencies(taskFiles, taskJars)   // Download any files added with addFile
+        val deserStart = System.currentTimeMillis()
         val deserializedTask = ser.deserialize[Task[_]](
             taskBytes, Thread.currentThread.getContextClassLoader)
+        val deserTime = System.currentTimeMillis() - deserStart
 
         // Run it
         val result: Any = deserializedTask.run(attemptId)
@@ -77,15 +79,19 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
         // executor does. This is useful to catch serialization errors early
         // on in development (so when users move their local Spark programs
         // to the cluster, they don't get surprised by serialization errors).
-        val resultToReturn = ser.deserialize[Any](ser.serialize(result))
+        val serResult = ser.serialize(result)
+        deserializedTask.metrics.get.resultSize = serResult.limit()
+        val resultToReturn = ser.deserialize[Any](serResult)
         val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
           ser.serialize(Accumulators.values))
         logInfo("Finished " + task)
         info.markSuccessful()
+        deserializedTask.metrics.get.executorRunTime = info.duration.toInt  //close enough
+        deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
 
         // If the threadpool has not already been shutdown, notify DAGScheduler
         if (!Thread.currentThread().isInterrupted)
-          listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, null)
+          listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null))
       } catch {
         case t: Throwable => {
           logError("Exception in task " + idInJob, t)
diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
new file mode 100644
index 0000000000..dd9f2d7e91
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
@@ -0,0 +1,80 @@
+package spark.scheduler
+
+import org.scalatest.FunSuite
+import spark.{SparkContext, LocalSparkContext}
+import scala.collection.mutable
+import org.scalatest.matchers.ShouldMatchers
+import spark.SparkContext._
+
+/**
+ *
+ */
+
+class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+
+  test("local metrics") {
+    sc = new SparkContext("local[4]", "test")
+    val listener = new SaveStageInfo
+    sc.addSparkListener(listener)
+    sc.addSparkListener(new StatsReportListener)
+
+    val d = sc.parallelize(1 to 1e4.toInt, 64)
+    d.count
+    listener.stageInfos.size should be (1)
+
+    val d2 = d.map{i => i -> i * 2}.setName("shuffle input 1")
+
+    val d3 = d.map{i => i -> (0 to (i % 5))}.setName("shuffle input 2")
+
+    val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => k -> (v1.size, v2.size)}
+    d4.setName("A Cogroup")
+
+    d4.collectAsMap
+
+    listener.stageInfos.size should be (4)
+    listener.stageInfos.foreach {stageInfo =>
+      //small test, so some tasks might take less than 1 millisecond, but average should be greater than 1 ms
+      checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration")
+      checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime")
+      checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime")
+      if (stageInfo.stage.rdd.name == d4.name) {
+        checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime")
+      }
+
+        stageInfo.taskInfos.foreach{case (taskInfo, taskMetrics) =>
+        taskMetrics.resultSize should be > (0l)
+        if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
+          taskMetrics.shuffleWriteMetrics should be ('defined)
+          taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l)
+        }
+        if (stageInfo.stage.rdd.name == d4.name) {
+          taskMetrics.shuffleReadMetrics should be ('defined)
+          val sm = taskMetrics.shuffleReadMetrics.get
+          sm.totalBlocksFetched should be > (0)
+          sm.shuffleReadMillis should be > (0l)
+          sm.localBlocksFetched should be > (0)
+          sm.remoteBlocksFetched should be (0)
+          sm.remoteBytesRead should be (0l)
+          sm.remoteFetchTime should be (0l)
+        }
+      }
+    }
+  }
+
+  def checkNonZeroAvg(m: Traversable[Long], msg: String) {
+    assert(m.sum / m.size.toDouble > 0.0, msg)
+  }
+
+  def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = {
+    val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name}
+    !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty
+  }
+
+  class SaveStageInfo extends SparkListener {
+    val stageInfos = mutable.Buffer[StageInfo]()
+    def onStageCompleted(stage: StageCompleted) {
+      stageInfos += stage.stageInfo
+    }
+  }
+
+}
-- 
GitLab