Commit 20f01a0a authored by Imran Rashid

enable task metrics in local mode, add tests

parent ec30188a
@@ -67,8 +67,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
       logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
       val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
       updateDependencies(taskFiles, taskJars) // Download any files added with addFile
+      val deserStart = System.currentTimeMillis()
       val deserializedTask = ser.deserialize[Task[_]](
         taskBytes, Thread.currentThread.getContextClassLoader)
+      val deserTime = System.currentTimeMillis() - deserStart

       // Run it
       val result: Any = deserializedTask.run(attemptId)
@@ -77,15 +79,19 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
         // executor does. This is useful to catch serialization errors early
         // on in development (so when users move their local Spark programs
         // to the cluster, they don't get surprised by serialization errors).
-        val resultToReturn = ser.deserialize[Any](ser.serialize(result))
+        val serResult = ser.serialize(result)
+        deserializedTask.metrics.get.resultSize = serResult.limit()
+        val resultToReturn = ser.deserialize[Any](serResult)
         val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
           ser.serialize(Accumulators.values))
         logInfo("Finished " + task)
         info.markSuccessful()
+        deserializedTask.metrics.get.executorRunTime = info.duration.toInt // close enough
+        deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt

         // If the threadpool has not already been shutdown, notify DAGScheduler
         if (!Thread.currentThread().isInterrupted)
-          listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, null)
+          listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null))
       } catch {
         case t: Throwable => {
           logError("Exception in task " + idInJob, t)

package spark.scheduler

import org.scalatest.FunSuite
import spark.{SparkContext, LocalSparkContext}
import scala.collection.mutable
import org.scalatest.matchers.ShouldMatchers
import spark.SparkContext._

/**
 *
 */
class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {

  test("local metrics") {
    sc = new SparkContext("local[4]", "test")
    val listener = new SaveStageInfo
    sc.addSparkListener(listener)
    sc.addSparkListener(new StatsReportListener)

    val d = sc.parallelize(1 to 1e4.toInt, 64)
    d.count
    listener.stageInfos.size should be (1)

    val d2 = d.map{i => i -> i * 2}.setName("shuffle input 1")
    val d3 = d.map{i => i -> (0 to (i % 5))}.setName("shuffle input 2")
    val d4 = d2.cogroup(d3, 64).map{case(k, (v1, v2)) => k -> (v1.size, v2.size)}
    d4.setName("A Cogroup")

    d4.collectAsMap

    listener.stageInfos.size should be (4)
    listener.stageInfos.foreach {stageInfo =>
      // small test, so some tasks might take less than 1 millisecond, but the average should be greater than zero
      checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration")
      checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime")
      checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime")
      if (stageInfo.stage.rdd.name == d4.name) {
        checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime")
      }

      stageInfo.taskInfos.foreach{case (taskInfo, taskMetrics) =>
        taskMetrics.resultSize should be > (0l)
        if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
          taskMetrics.shuffleWriteMetrics should be ('defined)
          taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l)
        }
        if (stageInfo.stage.rdd.name == d4.name) {
          taskMetrics.shuffleReadMetrics should be ('defined)
          val sm = taskMetrics.shuffleReadMetrics.get
          sm.totalBlocksFetched should be > (0)
          sm.shuffleReadMillis should be > (0l)
          sm.localBlocksFetched should be > (0)
          sm.remoteBlocksFetched should be (0)
          sm.remoteBytesRead should be (0l)
          sm.remoteFetchTime should be (0l)
        }
      }
    }
  }

  def checkNonZeroAvg(m: Traversable[Long], msg: String) {
    assert(m.sum / m.size.toDouble > 0.0, msg)
  }

  def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = {
    val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name}
    !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty
  }

  class SaveStageInfo extends SparkListener {
    val stageInfos = mutable.Buffer[StageInfo]()
    def onStageCompleted(stage: StageCompleted) {
      stageInfos += stage.stageInfo
    }
  }

}
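
For illustration only, a minimal driver sketch (not part of this commit) of how a program running in local mode could observe the metrics this change populates. It relies only on the listener API exercised by the test above (addSparkListener, StageCompleted, StageInfo.taskInfos, TaskMetrics fields) and assumes that API is visible to user code; the object name LocalMetricsExample and the sample job are hypothetical.

import spark.SparkContext
import spark.scheduler.{SparkListener, StageCompleted}

// Hypothetical example program; prints the per-task metrics that local mode now reports.
object LocalMetricsExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[4]", "local-metrics-example")
    // Print task timing and result-size metrics as each stage completes,
    // mirroring what SparkListenerSuite asserts on.
    sc.addSparkListener(new SparkListener {
      def onStageCompleted(stage: StageCompleted) {
        stage.stageInfo.taskInfos.foreach { case (info, metrics) =>
          println("task took " + info.duration + " ms, " +
            "executorRunTime=" + metrics.executorRunTime + " ms, " +
            "executorDeserializeTime=" + metrics.executorDeserializeTime + " ms, " +
            "resultSize=" + metrics.resultSize + " bytes")
        }
      }
    })
    // A small job so the listener has something to report.
    sc.parallelize(1 to 10000, 8).map(i => i -> i * 2).count()
    sc.stop()
  }
}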