diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dda194d9537c879d630ceba631d9301990d9c146..4cef0825dd6c0aab711df8a58700bd37fb91c0e0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -68,6 +68,11 @@ class DAGScheduler( eventQueue.put(BeginEvent(task, taskInfo)) } + // Called to report that a task has completed and results are being fetched remotely. + def taskGettingResult(task: Task[_], taskInfo: TaskInfo) { + eventQueue.put(GettingResultEvent(task, taskInfo)) + } + // Called by TaskScheduler to report task completions or failures. def taskEnded( task: Task[_], @@ -415,6 +420,9 @@ class DAGScheduler( case begin: BeginEvent => listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo)) + case gettingResult: GettingResultEvent => + listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo)) + case completion: CompletionEvent => listenerBus.post(SparkListenerTaskEnd( completion.task, completion.reason, completion.taskInfo, completion.taskMetrics)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index a5769c604195b572196c4171141cdd95a5f81bad..708d221d60caf8cd981780d513a89d372a1baf97 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -53,6 +53,9 @@ private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent private[scheduler] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent +private[scheduler] +case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent + private[scheduler] case class CompletionEvent( task: Task[_], reason: TaskEndReason, diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 324cd639b0a5710a74006ef19963903800dff521..a35081f7b10d7040d8b45302ce50941cef3e7960 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -31,6 +31,9 @@ case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents +case class SparkListenerTaskGettingResult( + task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents + case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, taskMetrics: TaskMetrics) extends SparkListenerEvents @@ -56,6 +59,12 @@ trait SparkListener { */ def onTaskStart(taskStart: SparkListenerTaskStart) { } + /** + * Called when a task begins remotely fetching its result (will not be called for tasks that do + * not need to fetch the result remotely). + */ + def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } + /** * Called when a task ends */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 4d3e4a17ba5620281d0b0e7238fd39907d814e46..d5824e79547974e643b348b12465fa6fe78a2fe0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging { sparkListeners.foreach(_.onJobEnd(jobEnd)) case taskStart: SparkListenerTaskStart => sparkListeners.foreach(_.onTaskStart(taskStart)) + case taskGettingResult: SparkListenerTaskGettingResult => + sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult)) case taskEnd: SparkListenerTaskEnd => sparkListeners.foreach(_.onTaskEnd(taskEnd)) case _ => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 7c2a422affbbfbe4be65817121d22c06f1bb3dfd..4bae26f3a6a885c73bd1639d61d226cbd06a5ea2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -31,9 +31,25 @@ class TaskInfo( val host: String, val taskLocality: TaskLocality.TaskLocality) { + /** + * The time when the task started remotely getting the result. Will not be set if the + * task result was sent immediately when the task finished (as opposed to sending an + * IndirectTaskResult and later fetching the result from the block manager). + */ + var gettingResultTime: Long = 0 + + /** + * The time when the task has completed successfully (including the time to remotely fetch + * results, if necessary). + */ var finishTime: Long = 0 + var failed = false + def markGettingResult(time: Long = System.currentTimeMillis) { + gettingResultTime = time + } + def markSuccessful(time: Long = System.currentTimeMillis) { finishTime = time } @@ -43,6 +59,8 @@ class TaskInfo( failed = true } + def gettingResult: Boolean = gettingResultTime != 0 + def finished: Boolean = finishTime != 0 def successful: Boolean = finished && !failed @@ -52,6 +70,8 @@ class TaskInfo( def status: String = { if (running) "RUNNING" + else if (gettingResult) + "GET RESULT" else if (failed) "FAILED" else if (successful) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index 4ea8bf88534cf1d90b691b5265101a9e235eaf5d..85033958ef54f4e1568a3023630621e2f9cd7b35 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -306,6 +306,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } + def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) { + taskSetManager.handleTaskGettingResult(tid) + } + def handleSuccessfulTask( taskSetManager: ClusterTaskSetManager, tid: Long, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 29093e3b4f511b1fe9d1e4fcadc9d20ac6e3d1bc..ee47aaffcae11d1e341626791047c5e1ae2bc9ca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -418,6 +418,12 @@ private[spark] class ClusterTaskSetManager( sched.dagScheduler.taskStarted(task, info) } + def handleTaskGettingResult(tid: Long) = { + val info = taskInfos(tid) + info.markGettingResult() + sched.dagScheduler.taskGettingResult(tasks(info.index), info) + } + /** * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala index 4312c46cc190c1279318942f0c150e743c36fe14..2064d97b49cc04f35cd638a65c57112abf3d9956 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala @@ -50,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche case directResult: DirectTaskResult[_] => directResult case IndirectTaskResult(blockId) => logDebug("Fetching indirect task result for TID %s".format(tid)) + scheduler.handleTaskGettingResult(taskSetManager, tid) val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId) if (!serializedTaskResult.isDefined) { /* We won't be able to get the task result if the machine that ran the task failed diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 9bb8a13ec45d86c5dda068aea29db989acfc8dcf..6b854740d6a2425e51cac0505a2e6ecc58cecf30 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -115,7 +115,13 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList taskList += ((taskStart.taskInfo, None, None)) stageIdToTaskInfos(sid) = taskList } - + + override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) + = synchronized { + // Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in + // stageToTaskInfos already has the updated status. + } + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val sid = taskEnd.task.stageId val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 42ca988f7a12995b2e6f322e587cfae7932b9b39..f7f599532a96c3c61ec8a1ed46359dca8a2f5b90 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -17,22 +17,25 @@ package org.apache.spark.scheduler -import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.{LocalSparkContext, SparkContext} -import scala.collection.mutable +import scala.collection.mutable.{Buffer, HashSet} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalatest.matchers.ShouldMatchers + +import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.SparkContext._ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers - with BeforeAndAfter { + with BeforeAndAfterAll { /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 - before { - sc = new SparkContext("local", "DAGSchedulerSuite") + override def afterAll { + System.clearProperty("spark.akka.frameSize") } test("basic creation of StageInfo") { + sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -53,6 +56,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("StageInfo with fewer tasks than partitions") { + sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -68,6 +72,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("local metrics") { + sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) @@ -129,15 +134,73 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } } + test("onTaskGettingResult() called when result fetched remotely") { + // Need to use local cluster mode here, because results are not ever returned through the + // block manager when using the LocalScheduler. + sc = new SparkContext("local-cluster[1,1,512]", "test") + + val listener = new SaveTaskEvents + sc.addSparkListener(listener) + + // Make a task whose result is larger than the akka frame size + System.setProperty("spark.akka.frameSize", "1") + val akkaFrameSize = + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x,y) => x) + assert(result === 1.to(akkaFrameSize).toArray) + + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val TASK_INDEX = 0 + assert(listener.startedTasks.contains(TASK_INDEX)) + assert(listener.startedGettingResultTasks.contains(TASK_INDEX)) + assert(listener.endedTasks.contains(TASK_INDEX)) + } + + test("onTaskGettingResult() not called when result sent directly") { + // Need to use local cluster mode here, because results are not ever returned through the + // block manager when using the LocalScheduler. + sc = new SparkContext("local-cluster[1,1,512]", "test") + + val listener = new SaveTaskEvents + sc.addSparkListener(listener) + + // Make a task whose result is larger than the akka frame size + val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x) + assert(result === 2) + + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val TASK_INDEX = 0 + assert(listener.startedTasks.contains(TASK_INDEX)) + assert(listener.startedGettingResultTasks.isEmpty == true) + assert(listener.endedTasks.contains(TASK_INDEX)) + } + def checkNonZeroAvg(m: Traversable[Long], msg: String) { assert(m.sum / m.size.toDouble > 0.0, msg) } class SaveStageInfo extends SparkListener { - val stageInfos = mutable.Buffer[StageInfo]() + val stageInfos = Buffer[StageInfo]() override def onStageCompleted(stage: StageCompleted) { stageInfos += stage.stage } } + class SaveTaskEvents extends SparkListener { + val startedTasks = new HashSet[Int]() + val startedGettingResultTasks = new HashSet[Int]() + val endedTasks = new HashSet[Int]() + + override def onTaskStart(taskStart: SparkListenerTaskStart) { + startedTasks += taskStart.taskInfo.index + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + endedTasks += taskEnd.taskInfo.index + } + + override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { + startedGettingResultTasks += taskGettingResult.taskInfo.index + } + } }