diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index c78e0ffca25bba8166cdc4a89c3cce5e85d1b06a..e24a15f015e1ce15a3688fc4a8def39614cbf295 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -104,6 +104,9 @@ private[spark] class Executor(
   // to send the result back.
   private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
 
+  // Limit of bytes for total size of results (default is 1GB)
+  private val maxResultSize = Utils.getMaxResultSize(conf)
+
   // Start worker thread pool
   val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
 
@@ -210,25 +213,27 @@ private[spark] class Executor(
         val resultSize = serializedDirectResult.limit
 
         // directSend = sending directly back to the driver
-        val (serializedResult, directSend) = {
-          if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
+        val serializedResult = {
+          if (resultSize > maxResultSize) {
+            logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
+              s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
+              s"dropping it.")
+            ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
+          } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
             val blockId = TaskResultBlockId(taskId)
             env.blockManager.putBytes(
               blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
-            (ser.serialize(new IndirectTaskResult[Any](blockId)), false)
+            logInfo(
+              s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
+            ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
           } else {
-            (serializedDirectResult, true)
+            logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
+            serializedDirectResult
           }
         }
 
         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
-        if (directSend) {
-          logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
-        } else {
-          logInfo(
-            s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
-        }
 
       } catch {
         case ffe: FetchFailedException => {
           val reason = ffe.toTaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 11c19eeb6e42c4568723bfe8478e689a6647ccaa..1f114a0207f7bc0e0554493c912b387b0d6bea66 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -31,8 +31,8 @@ import org.apache.spark.util.Utils
 private[spark] sealed trait TaskResult[T]
 
 /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
-private[spark]
-case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable
+private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
+  extends TaskResult[T] with Serializable
 
 /** A TaskResult that contains the task's return value and accumulator updates. */
 private[spark]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 4b5be68ec5f928a6e7ad8ac8446a5ce5554225ac..819b51e12ad8c954663e11f68efc4bdfd759bdbc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -47,9 +47,18 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
     getTaskResultExecutor.execute(new Runnable {
       override def run(): Unit = Utils.logUncaughtExceptions {
         try {
-          val result = serializer.get().deserialize[TaskResult[_]](serializedData) match {
-            case directResult: DirectTaskResult[_] => directResult
-            case IndirectTaskResult(blockId) =>
+          val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
+            case directResult: DirectTaskResult[_] =>
+              if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
+                return
+              }
+              (directResult, serializedData.limit())
+            case IndirectTaskResult(blockId, size) =>
+              if (!taskSetManager.canFetchMoreResults(size)) {
+                // dropped by executor if size is larger than maxResultSize
+                sparkEnv.blockManager.master.removeBlock(blockId)
+                return
+              }
               logDebug("Fetching indirect task result for TID %s".format(tid))
               scheduler.handleTaskGettingResult(taskSetManager, tid)
               val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
@@ -64,9 +73,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
               val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
                 serializedTaskResult.get)
               sparkEnv.blockManager.master.removeBlock(blockId)
-              deserializedResult
+              (deserializedResult, size)
           }
-          result.metrics.resultSize = serializedData.limit()
+
+          result.metrics.resultSize = size
           scheduler.handleSuccessfulTask(taskSetManager, tid, result)
         } catch {
           case cnf: ClassNotFoundException =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 376821f89c6b88b5f8389279e6ae8e8e7530b652..a9767340074a8dcd00a42f40713187eea493ac3a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -23,13 +23,12 @@ import java.util.Arrays
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
+import scala.math.{min, max}
 
 import org.apache.spark._
-import org.apache.spark.TaskState.TaskState
 import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{Clock, SystemClock}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.util.{Clock, SystemClock, Utils}
 
 /**
  * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
@@ -68,6 +67,9 @@ private[spark] class TaskSetManager(
   val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75)
   val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5)
 
+  // Limit of bytes for total size of results (default is 1GB)
+  val maxResultSize = Utils.getMaxResultSize(conf)
+
   // Serializer for closures and tasks.
   val env = SparkEnv.get
   val ser = env.closureSerializer.newInstance()
@@ -89,6 +91,8 @@ private[spark] class TaskSetManager(
   var stageId = taskSet.stageId
   var name = "TaskSet_" + taskSet.stageId.toString
   var parent: Pool = null
+  var totalResultSize = 0L
+  var calculatedTasks = 0
 
   val runningTasksSet = new HashSet[Long]
   override def runningTasks = runningTasksSet.size
@@ -515,12 +519,33 @@ private[spark] class TaskSetManager(
     index
   }
 
+  /**
+   * Marks the task as getting result and notifies the DAG Scheduler
+   */
   def handleTaskGettingResult(tid: Long) = {
     val info = taskInfos(tid)
     info.markGettingResult()
     sched.dagScheduler.taskGettingResult(info)
   }
 
+  /**
+   * Check whether has enough quota to fetch the result with `size` bytes
+   */
+  def canFetchMoreResults(size: Long): Boolean = synchronized {
+    totalResultSize += size
+    calculatedTasks += 1
+    if (maxResultSize > 0 && totalResultSize > maxResultSize) {
+      val msg = s"Total size of serialized results of ${calculatedTasks} tasks " +
+        s"(${Utils.bytesToString(totalResultSize)}) is bigger than maxResultSize " +
+        s"(${Utils.bytesToString(maxResultSize)})"
+      logError(msg)
+      abort(msg)
+      false
+    } else {
+      true
+    }
+  }
+
   /**
    * Marks the task as successful and notifies the DAGScheduler that a task has ended.
    */
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 68d378f3a212df956ac8df227ea241ed61654ebc..4e30d0d3813a2b4ba3886832170516f84650732d 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1720,6 +1720,11 @@ private[spark] object Utils extends Logging {
     method.invoke(obj, values.toSeq: _*)
   }
 
+  // Limit of bytes for total size of results (default is 1GB)
+  def getMaxResultSize(conf: SparkConf): Long = {
+    memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20
+  }
+
   /**
    * Return the current system LD_LIBRARY_PATH name
    */
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index c4e7a4bb7d38534fa5296a041821f08553fa9ea7..5768a3a733f00ac759cd82d9eddba5583d5cc396 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -40,7 +40,7 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule
       // Only remove the result once, since we'd like to test the case where the task eventually
      // succeeds.
      serializer.get().deserialize[TaskResult[_]](serializedData) match {
-        case IndirectTaskResult(blockId) =>
+        case IndirectTaskResult(blockId, size) =>
          sparkEnv.blockManager.master.removeBlock(blockId)
        case directResult: DirectTaskResult[_] =>
          taskSetManager.abort("Internal error: expect only indirect results")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index c0b07649eb6dd78aead0b983b693207f34a6dcbc..1809b5396d53eac0e790d71bd95a15ec258053bb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -563,6 +563,31 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     assert(manager.emittedTaskSizeWarning)
   }
 
+  test("abort the job if total size of results is too large") {
+    val conf = new SparkConf().set("spark.driver.maxResultSize", "2m")
+    sc = new SparkContext("local", "test", conf)
+
+    def genBytes(size: Int) = { (x: Int) =>
+      val bytes = Array.ofDim[Byte](size)
+      scala.util.Random.nextBytes(bytes)
+      bytes
+    }
+
+    // multiple 1k result
+    val r = sc.makeRDD(0 until 10, 10).map(genBytes(1024)).collect()
+    assert(10 === r.size )
+
+    // single 10M result
+    val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()}
+    assert(thrown.getMessage().contains("bigger than maxResultSize"))
+
+    // multiple 1M results
+    val thrown2 = intercept[SparkException] {
+      sc.makeRDD(0 until 10, 10).map(genBytes(1 << 20)).collect()
+    }
+    assert(thrown2.getMessage().contains("bigger than maxResultSize"))
+  }
+
   test("speculative and noPref task should be scheduled after node-local") {
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))
diff --git a/docs/configuration.md b/docs/configuration.md
index 3007706a2586e757c4b60f86e2569aa37e9af9b0..099972ca1af703c53920b59c1fb103db01fab59b 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -111,6 +111,18 @@ of the most common options to set are:
     (e.g. <code>512m</code>, <code>2g</code>).
   </td>
 </tr>
+<tr>
+  <td><code>spark.driver.maxResultSize</code></td>
+  <td>1g</td>
+  <td>
+    Limit of total size of serialized results of all partitions for each Spark action (e.g. collect).
+    Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size
+    is above this limit.
+    Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory
+    and memory overhead of objects in JVM). Setting a proper limit can protect the driver from
+    out-of-memory errors.
+  </td>
+</tr>
 <tr>
   <td><code>spark.serializer</code></td>
   <td>org.apache.spark.serializer.<br />JavaSerializer</td>
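Usage sketch (not part of the patch): with this change, an action whose accumulated serialized results exceed `spark.driver.maxResultSize` is aborted and surfaces as a SparkException whose message contains "bigger than maxResultSize", mirroring the TaskSetManagerSuite test above. The snippet below is a minimal illustration only; the local master URL, app name, and data sizes are arbitrary choices, not part of the commit.

    import org.apache.spark.{SparkConf, SparkContext, SparkException}

    // Cap the total size of results collected back to the driver at 2 MB (illustrative value).
    val conf = new SparkConf()
      .setMaster("local")                      // illustrative master
      .setAppName("maxResultSize-example")     // illustrative app name
      .set("spark.driver.maxResultSize", "2m")
    val sc = new SparkContext(conf)

    // Well under the limit: ten 1 KB partitions collect normally.
    val small = sc.parallelize(0 until 10, 10).map(_ => new Array[Byte](1024)).collect()

    // Ten 1 MB partitions exceed the limit in total, so the job is aborted by the driver.
    try {
      sc.parallelize(0 until 10, 10).map(_ => new Array[Byte](1 << 20)).collect()
    } catch {
      case e: SparkException => println(e.getMessage)  // "... is bigger than maxResultSize ..."
    }

    sc.stop()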