Skip to content
Snippets Groups Projects
Commit 85ab8114 authored by Reynold Xin's avatar Reynold Xin
Browse files

Moved non-serializable closure catching exception from submitStage to submitMissingTasks

parent f2422d4f
No related branches found
No related tags found
No related merge requests found
......@@ -465,18 +465,6 @@ class DAGScheduler(
/** Submits stage, but first recursively submits any missing parents. */
private def submitStage(stage: Stage) {
logDebug("submitStage(" + stage + ")")
// Preemptively serialize the stage RDD to make sure the tasks for this stage will be
// serializable. We are catching this exception here because it would be fairly hard to
// catch the non-serializable exception down the road, where we have several different
// implementations for local scheduler and cluster schedulers.
try {
SparkEnv.get.closureSerializer.newInstance().serialize(stage.rdd)
} catch {
case e: NotSerializableException => abortStage(stage, e.toString)
return
}
if (!waiting(stage) && !running(stage) && !failed(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
......@@ -515,6 +503,19 @@ class DAGScheduler(
}
}
if (tasks.size > 0) {
// Preemptively serialize a task to make sure it can be serialized. We are catching this
// exception here because it would be fairly hard to catch the non-serializable exception
// down the road, where we have several different implementations for local scheduler and
// cluster schedulers.
try {
SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
} catch {
case e: NotSerializableException =>
abortStage(stage, e.toString)
running -= stage
return
}
sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size)))
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
......
......@@ -96,18 +96,28 @@ class FailureSuite extends FunSuite with LocalSparkContext {
// A task whose closure captures a non-serializable object must abort the job
// with a SparkException whose message mentions NotSerializableException —
// whether the bad closure sits in the final result stage, an earlier shuffle
// stage, or an action's foreach function.
test("failure because task closure is not serializable") {
  sc = new SparkContext("local[1,1]", "test")
  val a = new NonSerializable

  // Runs `job` and asserts that it fails with the expected serialization error.
  def assertNotSerializable(job: => Unit): Unit = {
    val failure = intercept[SparkException](job)
    assert(failure.getClass === classOf[SparkException])
    assert(failure.getMessage.contains("NotSerializableException"))
  }

  // Non-serializable closure in the final result stage
  assertNotSerializable {
    sc.parallelize(1 to 10, 2).map(x => a).count()
  }
  // Non-serializable closure in an earlier stage
  assertNotSerializable {
    sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
  }
  // Non-serializable closure in foreach function
  assertNotSerializable {
    sc.parallelize(1 to 10, 2).foreach(x => println(a))
  }

  FailureSuiteState.clear()
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment