diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 7bf50de660ec3302f31d95ec68f2534a3b9795bd..9b45fc29388eaba971624bd031b891abbeac2de5 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -17,19 +17,17 @@ package spark.scheduler -import cluster.TaskInfo -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.LinkedBlockingQueue -import java.util.concurrent.TimeUnit +import java.io.NotSerializableException import java.util.Properties +import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import spark._ import spark.executor.TaskMetrics -import spark.partial.ApproximateActionListener -import spark.partial.ApproximateEvaluator -import spark.partial.PartialResult +import spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} +import spark.scheduler.cluster.TaskInfo import spark.storage.{BlockManager, BlockManagerMaster} import spark.util.{MetadataCleaner, TimeStampedHashMap} @@ -263,7 +261,8 @@ class DAGScheduler( assert(partitions.size > 0) val waiter = new JobWaiter(partitions.size, resultHandler) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) + val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, + properties) return (toSubmit, waiter) } @@ -288,7 +287,7 @@ class DAGScheduler( "Total number of partitions: " + maxPartitions) } - val (toSubmit, waiter) = prepareJob( + val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob( finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties) eventQueue.put(toSubmit) waiter.awaitResult() match { @@ -512,6 +511,19 @@ class DAGScheduler( } } if (tasks.size > 0) { + // Preemptively serialize a task to make sure it can be serialized. We are catching this + // exception here because it would be fairly hard to catch the non-serializable exception + // down the road, where we have several different implementations for local scheduler and + // cluster schedulers. + try { + SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head) + } catch { + case e: NotSerializableException => + abortStage(stage, e.toString) + running -= stage + return + } + sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size))) logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") myPending ++= tasks diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala index 6c847b8fef8562886accfd13407aa110007a5277..5b133cdd6e867947bb324b068fe352cefd3c1178 100644 --- a/core/src/test/scala/spark/FailureSuite.scala +++ b/core/src/test/scala/spark/FailureSuite.scala @@ -18,9 +18,6 @@ package spark import org.scalatest.FunSuite -import org.scalatest.prop.Checkers - -import scala.collection.mutable.ArrayBuffer import SparkContext._ @@ -40,7 +37,7 @@ object FailureSuiteState { } class FailureSuite extends FunSuite with LocalSparkContext { - + // Run a 3-task map job in which task 1 deterministically fails once, and check // whether the job completes successfully and we ran 4 tasks in total. test("failure in a single-stage job") { @@ -66,7 +63,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { test("failure in a two-stage job") { sc = new SparkContext("local[1,1]", "test") val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map { - case (k, v) => + case (k, v) => FailureSuiteState.synchronized { FailureSuiteState.tasksRun += 1 if (k == 1 && FailureSuiteState.tasksFailed == 0) { @@ -87,12 +84,40 @@ class FailureSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local[1,1]", "test") val results = sc.makeRDD(1 to 3).map(x => new NonSerializable) - val thrown = intercept[spark.SparkException] { + val thrown = intercept[SparkException] { results.collect() } - assert(thrown.getClass === classOf[spark.SparkException]) + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getMessage.contains("NotSerializableException")) + + FailureSuiteState.clear() + } + + test("failure because task closure is not serializable") { + sc = new SparkContext("local[1,1]", "test") + val a = new NonSerializable + + // Non-serializable closure in the final result stage + val thrown = intercept[SparkException] { + sc.parallelize(1 to 10, 2).map(x => a).count() + } + assert(thrown.getClass === classOf[SparkException]) assert(thrown.getMessage.contains("NotSerializableException")) + // Non-serializable closure in an earlier stage + val thrown1 = intercept[SparkException] { + sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count() + } + assert(thrown1.getClass === classOf[SparkException]) + assert(thrown1.getMessage.contains("NotSerializableException")) + + // Non-serializable closure in foreach function + val thrown2 = intercept[SparkException] { + sc.parallelize(1 to 10, 2).foreach(x => println(a)) + } + assert(thrown2.getClass === classOf[SparkException]) + assert(thrown2.getMessage.contains("NotSerializableException")) + FailureSuiteState.clear() }