From 57579934f0454f258615c10e69ac2adafc5b9835 Mon Sep 17 00:00:00 2001
From: hhd <henrydavidge@gmail.com>
Date: Mon, 25 Nov 2013 17:17:17 -0500
Subject: [PATCH] Emit warning when task size > 100KB

---
 .../org/apache/spark/scheduler/DAGScheduler.scala | 15 +++++++++++++++
 .../org/apache/spark/scheduler/StageInfo.scala    |  1 +
 .../org/apache/spark/scheduler/TaskInfo.scala     |  2 ++
 .../scheduler/cluster/ClusterTaskSetManager.scala |  1 +
 4 files changed, 19 insertions(+)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 42bb3884c8..4457525ac8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -110,6 +110,9 @@ class DAGScheduler(
   // resubmit failed stages
   val POLL_TIMEOUT = 10L
 
+  // Warns the user if a stage contains a task with size greater than this value (in KB)
+  val TASK_SIZE_TO_WARN = 100
+
   private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor {
     override def preStart() {
       context.system.scheduler.schedule(RESUBMIT_TIMEOUT milliseconds, RESUBMIT_TIMEOUT milliseconds) {
@@ -430,6 +433,18 @@ class DAGScheduler(
         handleExecutorLost(execId)
 
       case BeginEvent(task, taskInfo) =>
+        for (
+          job <- idToActiveJob.get(task.stageId);
+          stage <- stageIdToStage.get(task.stageId);
+          stageInfo <- stageToInfos.get(stage)
+        ) {
+          if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) {
+            stageInfo.emittedTaskSizeWarning = true
+            logWarning(("Stage %d (%s) contains a task of very large " +
+              "size (%d KB). The maximum recommended task size is %d KB.").format(
+              task.stageId, stageInfo.name, taskInfo.serializedSize / 1024, TASK_SIZE_TO_WARN))
+          }
+        }
         listenerBus.post(SparkListenerTaskStart(task, taskInfo))
 
       case GettingResultEvent(task, taskInfo) =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 93599dfdc8..e9f2198a00 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -33,4 +33,5 @@ class StageInfo(
   val name = stage.name
   val numPartitions = stage.numPartitions
   val numTasks = stage.numTasks
+  var emittedTaskSizeWarning = false
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 4bae26f3a6..3c22edd524 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -46,6 +46,8 @@ class TaskInfo(
 
   var failed = false
 
+  var serializedSize: Int = 0
+
   def markGettingResult(time: Long = System.currentTimeMillis) {
     gettingResultTime = time
   }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 4c5eca8537..8884ea85a3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -377,6 +377,7 @@ private[spark] class ClusterTaskSetManager(
           logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
             taskSet.id, index, serializedTask.limit, timeTaken))
           val taskName = "task %s:%d".format(taskSet.id, index)
+          info.serializedSize = serializedTask.limit
           if (taskAttempts(index).size == 1)
             taskStarted(task,info)
           return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
-- 
GitLab