diff --git a/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
new file mode 100644
index 0000000000000000000000000000000000000000..9df61062e1f854bcd10a961da41d3e6400e827c5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * Exception thrown when a task cannot be serialized.
+ */
+private[spark] class TaskNotSerializableException(error: Throwable) extends Exception(error)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 259621d263d7cc297f1359e05870879377672e9b..61d09d73e17cb69185f80ef5da6ac70b1e7b7306 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -866,26 +866,6 @@ class DAGScheduler(
     }
 
     if (tasks.size > 0) {
-      // Preemptively serialize a task to make sure it can be serialized. We are catching this
-      // exception here because it would be fairly hard to catch the non-serializable exception
-      // down the road, where we have several different implementations for local scheduler and
-      // cluster schedulers.
-      //
-      // We've already serialized RDDs and closures in taskBinary, but here we check for all other
-      // objects such as Partition.
-      try {
-        closureSerializer.serialize(tasks.head)
-      } catch {
-        case e: NotSerializableException =>
-          abortStage(stage, "Task not serializable: " + e.toString)
-          runningStages -= stage
-          return
-        case NonFatal(e) => // Other exceptions, such as IllegalArgumentException from Kryo.
-          abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
-          runningStages -= stage
-          return
-      }
-
       logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
       stage.pendingTasks ++= tasks
       logDebug("New pending tasks: " + stage.pendingTasks)
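The block removed above only test-serialized `tasks.head`, so a non-serializable object hiding in a later task's `Partition` passed the check and then hung the scheduler (SPARK-4349); the replacement strategy, in the scheduler changes below, serializes every task at offer time. A minimal sketch of the failure shape the old check missed (hypothetical names, illustration only, not part of the patch):

    import java.io.{IOException, ObjectOutputStream}

    // An element whose partition data cannot be written out, even though the
    // stage's closure and RDD serialize fine in taskBinary.
    class FailsOnWrite extends Serializable {
      @throws(classOf[IOException])
      private def writeObject(out: ObjectOutputStream): Unit =
        throw new IOException("not actually serializable")
    }

    // sc.parallelize(Seq("ok", new FailsOnWrite), 2).collect()
    // The head task (partition 0, "ok") serializes, so the removed check passed;
    // the second task failed only at launch time, which the old code never caught.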
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index a41f3eef195d28abeef75887758c1e780fc5d3fa..a1dfb0106259124c968c1038ab05c2d10140a2c1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -31,6 +31,7 @@ import scala.util.Random
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.scheduler.TaskLocality.TaskLocality
 import org.apache.spark.util.Utils
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.storage.BlockManagerId
@@ -209,6 +210,40 @@ private[spark] class TaskSchedulerImpl(
       .format(manager.taskSet.id, manager.parent.name))
   }
 
+  private def resourceOfferSingleTaskSet(
+      taskSet: TaskSetManager,
+      maxLocality: TaskLocality,
+      shuffledOffers: Seq[WorkerOffer],
+      availableCpus: Array[Int],
+      tasks: Seq[ArrayBuffer[TaskDescription]]): Boolean = {
+    var launchedTask = false
+    for (i <- 0 until shuffledOffers.size) {
+      val execId = shuffledOffers(i).executorId
+      val host = shuffledOffers(i).host
+      if (availableCpus(i) >= CPUS_PER_TASK) {
+        try {
+          for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
+            tasks(i) += task
+            val tid = task.taskId
+            taskIdToTaskSetId(tid) = taskSet.taskSet.id
+            taskIdToExecutorId(tid) = execId
+            executorsByHost(host) += execId
+            availableCpus(i) -= CPUS_PER_TASK
+            assert(availableCpus(i) >= 0)
+            launchedTask = true
+          }
+        } catch {
+          case e: TaskNotSerializableException =>
+            logError(s"Resource offer failed, task set ${taskSet.name} was not serializable")
+            // Do not offer resources for this task, but don't throw an error to allow other
+            // task sets to be submitted.
+            return launchedTask
+        }
+      }
+    }
+    return launchedTask
+  }
+
   /**
    * Called by cluster manager to offer resources on slaves. We respond by asking our active task
    * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
@@ -251,23 +286,8 @@ private[spark] class TaskSchedulerImpl(
     var launchedTask = false
     for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
       do {
-        launchedTask = false
-        for (i <- 0 until shuffledOffers.size) {
-          val execId = shuffledOffers(i).executorId
-          val host = shuffledOffers(i).host
-          if (availableCpus(i) >= CPUS_PER_TASK) {
-            for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
-              tasks(i) += task
-              val tid = task.taskId
-              taskIdToTaskSetId(tid) = taskSet.taskSet.id
-              taskIdToExecutorId(tid) = execId
-              executorsByHost(host) += execId
-              availableCpus(i) -= CPUS_PER_TASK
-              assert(availableCpus(i) >= 0)
-              launchedTask = true
-            }
-          }
-        }
+        launchedTask = resourceOfferSingleTaskSet(
+            taskSet, maxLocality, shuffledOffers, availableCpus, tasks)
       } while (launchedTask)
     }
 
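`resourceOfferSingleTaskSet` wraps each task set's offer in its own `try`, so a task set that dies during serialization cannot abort scheduling for the sets that follow it. A self-contained sketch of that isolation pattern (hypothetical, simplified types; not Spark code):

    object OfferIsolationSketch {
      final case class Offer(execId: String)

      // Mirror of the catch above: a failing producer is skipped, the rest
      // still receive their offers.
      def offerAll(sets: Seq[Offer => Option[String]], offer: Offer): Seq[String] =
        sets.flatMap { set =>
          try set(offer) catch {
            case e: RuntimeException =>
              println(s"skipping task set that failed to serialize: $e")
              None // like `return launchedTask` above: give up on this set only
          }
        }
    }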
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 28e6147509f78c0aef1af736b69561ab0a44b951..4667850917151ec4fb9c4f4a67627762ecdbb12a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -18,12 +18,14 @@
 package org.apache.spark.scheduler
 
 import java.io.NotSerializableException
+import java.nio.ByteBuffer
 import java.util.Arrays
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 import scala.math.{min, max}
+import scala.util.control.NonFatal
 
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
@@ -417,6 +419,7 @@ private[spark] class TaskSetManager(
    * @param host  the host Id of the offered resource
    * @param maxLocality the maximum locality we want to schedule the tasks at
    */
+  @throws[TaskNotSerializableException]
   def resourceOffer(
       execId: String,
       host: String,
@@ -456,10 +459,17 @@ private[spark] class TaskSetManager(
         }
         // Serialize and return the task
         val startTime = clock.getTime()
-        // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
-        // we assume the task can be serialized without exceptions.
-        val serializedTask = Task.serializeWithDependencies(
-          task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+        val serializedTask: ByteBuffer = try {
+          Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+        } catch {
+          // If the task cannot be serialized, then there's no point in re-attempting
+          // the task, as it will always fail. So just abort the whole task set.
+          case NonFatal(e) =>
+            val msg = s"Failed to serialize task $taskId, not attempting to retry it."
+            logError(msg, e)
+            abort(s"$msg Exception during serialization: $e")
+            throw new TaskNotSerializableException(e)
+        }
         if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 &&
             !emittedTaskSizeWarning) {
           emittedTaskSizeWarning = true
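With the `abort` plus `throw` pair above, a job whose tasks cannot be serialized now fails fast instead of hanging. A hedged sketch of what driver code observes; that the abort surfaces as a `SparkException` is an assumption about how failed task sets are reported to the driver, not something this diff shows (`FailsOnWrite` is the hypothetical class sketched earlier):

    try {
      sc.parallelize(Seq(new FailsOnWrite, new FailsOnWrite)).collect()
    } catch {
      case e: org.apache.spark.SparkException =>
        // Expected to carry the abort reason built above, i.e. a message along
        // the lines of "Failed to serialize task ..., not attempting to retry it."
        println(e.getMessage)
    }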
diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
index 0b6511a80df1dd6eaf934709864d9dd0ca4827a4..3d2700b7e6be45b51324a76873be22c643413147 100644
--- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
@@ -30,7 +30,7 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
 
   var conf = new SparkConf(false)
 
   override def beforeAll() {
-    _sc = new SparkContext("local", "test", conf)
+    _sc = new SparkContext("local[4]", "test", conf)
     super.beforeAll()
   }
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6836e9ab0fd6bf7b6f5f9dc6fb09c5625b423ae8..0deb9b18b868887fba1b1f761e9ddba495e00816 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.rdd
 
+import java.io.{ObjectInputStream, ObjectOutputStream, IOException}
+
+import com.esotericsoftware.kryo.KryoException
+
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
@@ -887,6 +891,23 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(ancestors6.count(_.isInstanceOf[CyclicalDependencyRDD[_]]) === 3)
   }
 
+  test("task serialization exception should not hang scheduler") {
+    class BadSerializable extends Serializable {
+      @throws(classOf[IOException])
+      private def writeObject(out: ObjectOutputStream): Unit = throw new KryoException("Bad serialization")
+
+      @throws(classOf[IOException])
+      private def readObject(in: ObjectInputStream): Unit = {}
+    }
+    // Note that in the original bug, SPARK-4349, which this test verifies, the job would only
+    // hang if there were more threads in the SparkContext than objects in this sequence.
+    intercept[Throwable] {
+      sc.parallelize(Seq(new BadSerializable, new BadSerializable)).collect()
+    }
+    // Check that the context has not crashed
+    sc.parallelize(1 to 100).map(x => x * 2).collect()
+  }
+
   /** A contrived RDD that allows the manual addition of dependencies after creation. */
   private class CyclicalDependencyRDD[T: ClassTag] extends RDD[T](sc, Nil) {
     private val mutableDependencies: ArrayBuffer[Dependency[_]] = ArrayBuffer.empty
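Both this test's `BadSerializable` and the `NotSerializableFakeTask` fixture below rely on the standard Java serialization hook: `ObjectOutputStream` reflectively invokes a `private writeObject(ObjectOutputStream)` method if one is defined, so throwing from it makes an otherwise-`Serializable` class fail on demand. A standalone sketch (hypothetical names):

    import java.io.{ByteArrayOutputStream, IOException, ObjectOutputStream}

    object WriteObjectHookDemo extends App {
      class FailsToWrite extends Serializable {
        @throws(classOf[IOException])
        private def writeObject(out: ObjectOutputStream): Unit =
          throw new IOException("refusing to serialize")
      }

      val oos = new ObjectOutputStream(new ByteArrayOutputStream())
      try oos.writeObject(new FailsToWrite)
      catch { case e: IOException => println(s"thrown by the hook: $e") }
    }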
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6b75c98839e033e2da8e5e52383a252c74777025
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.{ObjectInputStream, ObjectOutputStream, IOException}
+
+import org.apache.spark.TaskContext
+
+/**
+ * A Task implementation that fails to serialize.
+ */
+private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) extends Task[Array[Byte]](stageId, 0) {
+  override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
+  override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
+
+  @throws(classOf[IOException])
+  private def writeObject(out: ObjectOutputStream): Unit = {
+    if (stageId == 0) {
+      throw new IllegalStateException("Cannot serialize")
+    }
+  }
+
+  @throws(classOf[IOException])
+  private def readObject(in: ObjectInputStream): Unit = {}
+}
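Only `stageId == 0` throws, so one class yields both a failing and a healthy task; the suites below exploit this to check that a non-serializable task set does not poison a serializable one. Usage shape, with the same constructor arguments the tests use:

    val failing = new NotSerializableFakeTask(1, 0) // stageId == 0: writeObject throws
    val healthy = new NotSerializableFakeTask(0, 1) // stageId == 1: serializes normally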
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 8874cf00e999336819bca053b664a6477b89d3a9..add13f5b217653eba1e280c015b1da65941ca40d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -100,4 +100,34 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin
     assert(1 === taskDescriptions.length)
     assert("executor0" === taskDescriptions(0).executorId)
   }
+
+  test("Scheduler does not crash when tasks are not serializable") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskCpus = 2
+
+    sc.conf.set("spark.task.cpus", taskCpus.toString)
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+    val numFreeCores = 1
+    taskScheduler.setDAGScheduler(dagScheduler)
+    val taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null)
+    val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", taskCpus),
+      new WorkerOffer("executor1", "host1", numFreeCores))
+    taskScheduler.submitTasks(taskSet)
+    var taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
+    assert(0 === taskDescriptions.length)
+
+    // Now check that we can still submit tasks. Even if one task set contains
+    // non-serializable tasks, the other task sets should still be processed without error.
+    taskScheduler.submitTasks(taskSet)
+    taskScheduler.submitTasks(FakeTask.createTaskSet(1))
+    taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
+    assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
+  }
+
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 472191551a01ffa2ca6a8622d62f3c96c47966a1..84b9b788237bfa83161485ee86c56d214d713f1a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.scheduler
 
+import java.io.{ObjectInputStream, ObjectOutputStream, IOException}
 import java.util.Random
 
 import scala.collection.mutable.ArrayBuffer
@@ -563,6 +564,19 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     assert(manager.emittedTaskSizeWarning)
   }
 
+  test("Not serializable exception thrown if the task cannot be serialized") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+
+    val taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
+
+    intercept[TaskNotSerializableException] {
+      manager.resourceOffer("exec1", "host1", ANY)
+    }
+    assert(manager.isZombie)
+  }
+
   test("abort the job if total size of results is too large") {
     val conf = new SparkConf().set("spark.driver.maxResultSize", "2m")
     sc = new SparkContext("local", "test", conf)
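The final `isZombie` assertion leans on `abort` marking the manager as a zombie. A follow-up sketch of the behaviour that implies (an assumption based on `resourceOffer` guarding its body with `if (!isZombie)`, which this diff does not show):

    // After the aborting offer above, the manager stops offering tasks rather
    // than throwing again:
    // manager.resourceOffer("exec1", "host1", ANY) // expected to return None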