Skip to content
Snippets Groups Projects
Commit 876125b9 authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Merge pull request #726 from rxin/spark-826

SPARK-829: scheduler shouldn't hang if a task contains unserializable objects in its closure
parents 6a31b719 85ab8114
No related branches found
No related tags found
No related merge requests found
......@@ -17,19 +17,17 @@
package spark.scheduler
import cluster.TaskInfo
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import spark._
import spark.executor.TaskMetrics
import spark.partial.ApproximateActionListener
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
import spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import spark.scheduler.cluster.TaskInfo
import spark.storage.{BlockManager, BlockManagerMaster}
import spark.util.{MetadataCleaner, TimeStampedHashMap}
......@@ -263,7 +261,8 @@ class DAGScheduler(
assert(partitions.size > 0)
val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
properties)
return (toSubmit, waiter)
}
......@@ -288,7 +287,7 @@ class DAGScheduler(
"Total number of partitions: " + maxPartitions)
}
val (toSubmit, waiter) = prepareJob(
val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
eventQueue.put(toSubmit)
waiter.awaitResult() match {
......@@ -512,6 +511,19 @@ class DAGScheduler(
}
}
if (tasks.size > 0) {
// Preemptively serialize a task to make sure it can be serialized. We are catching this
// exception here because it would be fairly hard to catch the non-serializable exception
// down the road, where we have several different implementations for local scheduler and
// cluster schedulers.
try {
SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
} catch {
case e: NotSerializableException =>
abortStage(stage, e.toString)
running -= stage
return
}
sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size)))
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
......
......@@ -18,9 +18,6 @@
package spark
import org.scalatest.FunSuite
import org.scalatest.prop.Checkers
import scala.collection.mutable.ArrayBuffer
import SparkContext._
......@@ -40,7 +37,7 @@ object FailureSuiteState {
}
class FailureSuite extends FunSuite with LocalSparkContext {
// Run a 3-task map job in which task 1 deterministically fails once, and check
// whether the job completes successfully and we ran 4 tasks in total.
test("failure in a single-stage job") {
......@@ -66,7 +63,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
test("failure in a two-stage job") {
sc = new SparkContext("local[1,1]", "test")
val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map {
case (k, v) =>
case (k, v) =>
FailureSuiteState.synchronized {
FailureSuiteState.tasksRun += 1
if (k == 1 && FailureSuiteState.tasksFailed == 0) {
......@@ -87,12 +84,40 @@ class FailureSuite extends FunSuite with LocalSparkContext {
sc = new SparkContext("local[1,1]", "test")
val results = sc.makeRDD(1 to 3).map(x => new NonSerializable)
val thrown = intercept[spark.SparkException] {
val thrown = intercept[SparkException] {
results.collect()
}
assert(thrown.getClass === classOf[spark.SparkException])
assert(thrown.getClass === classOf[SparkException])
assert(thrown.getMessage.contains("NotSerializableException"))
FailureSuiteState.clear()
}
test("failure because task closure is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable
// Non-serializable closure in the final result stage
val thrown = intercept[SparkException] {
sc.parallelize(1 to 10, 2).map(x => a).count()
}
assert(thrown.getClass === classOf[SparkException])
assert(thrown.getMessage.contains("NotSerializableException"))
// Non-serializable closure in an earlier stage
val thrown1 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
}
assert(thrown1.getClass === classOf[SparkException])
assert(thrown1.getMessage.contains("NotSerializableException"))
// Non-serializable closure in foreach function
val thrown2 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).foreach(x => println(a))
}
assert(thrown2.getClass === classOf[SparkException])
assert(thrown2.getMessage.contains("NotSerializableException"))
FailureSuiteState.clear()
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment