diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 29e879aa4259c41d7042a71617ee30fe4645462f..5fcd807aff2980a62f4ef8ebd109770570c29ab7 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -17,19 +17,17 @@ package spark.scheduler -import cluster.TaskInfo -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.LinkedBlockingQueue -import java.util.concurrent.TimeUnit +import java.io.NotSerializableException import java.util.Properties +import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import spark._ import spark.executor.TaskMetrics -import spark.partial.ApproximateActionListener -import spark.partial.ApproximateEvaluator -import spark.partial.PartialResult +import spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} +import spark.scheduler.cluster.TaskInfo import spark.storage.{BlockManager, BlockManagerMaster} import spark.util.{MetadataCleaner, TimeStampedHashMap} @@ -258,7 +256,8 @@ class DAGScheduler( assert(partitions.size > 0) val waiter = new JobWaiter(partitions.size, resultHandler) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) + val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, + properties) return (toSubmit, waiter) } @@ -283,7 +282,7 @@ class DAGScheduler( "Total number of partitions: " + maxPartitions) } - val (toSubmit, waiter) = prepareJob( + val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob( finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties) eventQueue.put(toSubmit) waiter.awaitResult() match { @@ -466,6 +465,18 @@ class DAGScheduler( /** Submits stage, but first recursively submits any missing parents. */ private def submitStage(stage: Stage) { logDebug("submitStage(" + stage + ")") + + // Preemptively serialize the stage RDD to make sure the tasks for this stage will be + // serializable. We are catching this exception here because it would be fairly hard to + // catch the non-serializable exception down the road, where we have several different + // implementations for local scheduler and cluster schedulers. + try { + SparkEnv.get.closureSerializer.newInstance().serialize(stage.rdd) + } catch { + case e: NotSerializableException => abortStage(stage, e.toString) + return + } + if (!waiting(stage) && !running(stage) && !failed(stage)) { val missing = getMissingParentStages(stage).sortBy(_.id) logDebug("missing: " + missing) diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala index 6c847b8fef8562886accfd13407aa110007a5277..c3c52f9118b78376e018ec460534752716cdc2a6 100644 --- a/core/src/test/scala/spark/FailureSuite.scala +++ b/core/src/test/scala/spark/FailureSuite.scala @@ -18,9 +18,6 @@ package spark import org.scalatest.FunSuite -import org.scalatest.prop.Checkers - -import scala.collection.mutable.ArrayBuffer import SparkContext._ @@ -40,7 +37,7 @@ object FailureSuiteState { } class FailureSuite extends FunSuite with LocalSparkContext { - + // Run a 3-task map job in which task 1 deterministically fails once, and check // whether the job completes successfully and we ran 4 tasks in total. test("failure in a single-stage job") { @@ -66,7 +63,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { test("failure in a two-stage job") { sc = new SparkContext("local[1,1]", "test") val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map { - case (k, v) => + case (k, v) => FailureSuiteState.synchronized { FailureSuiteState.tasksRun += 1 if (k == 1 && FailureSuiteState.tasksFailed == 0) { @@ -87,15 +84,33 @@ class FailureSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local[1,1]", "test") val results = sc.makeRDD(1 to 3).map(x => new NonSerializable) - val thrown = intercept[spark.SparkException] { + val thrown = intercept[SparkException] { results.collect() } - assert(thrown.getClass === classOf[spark.SparkException]) + assert(thrown.getClass === classOf[SparkException]) assert(thrown.getMessage.contains("NotSerializableException")) FailureSuiteState.clear() } + test("failure because task closure is not serializable") { + sc = new SparkContext("local[1,1]", "test") + val a = new NonSerializable + val thrown = intercept[SparkException] { + sc.parallelize(1 to 10, 2).map(x => a).count() + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getMessage.contains("NotSerializableException")) + + val thrown1 = intercept[SparkException] { + sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count() + } + assert(thrown1.getClass === classOf[SparkException]) + assert(thrown1.getMessage.contains("NotSerializableException")) + + FailureSuiteState.clear() + } + // TODO: Need to add tests with shuffle fetch failures. }