From 6847e93cf427aa971dac1ea261c1443eebf4089e Mon Sep 17 00:00:00 2001
From: Andrew Ash <andrew@andrewash.com>
Date: Mon, 14 Aug 2017 22:48:08 +0800
Subject: [PATCH] [SPARK-21563][CORE] Fix race condition when serializing
 TaskDescriptions and adding jars

## What changes were proposed in this pull request?

Fix the race condition between serializing TaskDescriptions and adding jars
by keeping the set of jars and files for a TaskSet constant across the
lifetime of the TaskSet. Without this, TaskDescription serialization can
produce invalid output when new files/jars are added concurrently while the
TaskDescription is being serialized.
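To make the fix concrete, here is a minimal sketch of the snapshot-on-construction pattern this patch applies. `SnapshotSketch`, `TaskSetLike`, and `describeTask` are hypothetical names for illustration, not Spark APIs: the live jars map is copied once when a TaskSet-like object is created, so a concurrent `addJar` can no longer mutate the collection that in-flight task serialization reads.

```scala
import scala.collection.mutable.HashMap

// Hypothetical stand-ins for SparkContext.addedJars and TaskSetManager;
// this illustrates the fix, it does not reproduce Spark's classes.
object SnapshotSketch {
  // Shared mutable registry that addJar() keeps appending to (the racy map).
  val addedJars = HashMap[String, Long]()

  class TaskSetLike {
    // The fix: snapshot the map once at TaskSet creation, so every task
    // serialized from this TaskSet sees the same jars for its whole lifetime.
    private val jarsSnapshot = HashMap[String, Long](addedJars.toSeq: _*)

    // Task serialization reads only the snapshot, never the live map.
    def describeTask(): Map[String, Long] = jarsSnapshot.toMap
  }

  def main(args: Array[String]): Unit = {
    addedJars("a.jar") = 1L
    val taskSet = new TaskSetLike

    // A concurrent addJar() mid-TaskSet mutates the live map...
    addedJars("b.jar") = 2L

    // ...but tasks from the existing TaskSet still see the original set.
    assert(taskSet.describeTask() == Map("a.jar" -> 1L))
  }
}
```

Before the patch, the serialization path iterated `sched.sc.addedJars` directly, so an `addJar` from another thread could mutate the map mid-iteration; the diff below replaces that shared reference with per-TaskSet copies.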
## How was this patch tested?

An additional unit test ensures that the jars/files contained in the
TaskDescription remain constant throughout the lifetime of the TaskSet.

Author: Andrew Ash <andrew@andrewash.com>

Closes #18913 from ash211/SPARK-21563.
---
 .../scala/org/apache/spark/SparkContext.scala |  7 +++++++
 .../spark/scheduler/TaskSetManager.scala      |  8 ++++++--
 .../spark/scheduler/TaskSetManagerSuite.scala | 34 +++++++++++++++++++++++++++++++++-
 3 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 5316468914..136f0af7b2 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1490,6 +1490,8 @@ class SparkContext(config: SparkConf) extends Logging {
   /**
    * Add a file to be downloaded with this Spark job on every node.
    *
+   * If a file is added during execution, it will not be available until the next TaskSet starts.
+   *
    * @param path can be either a local file, a file in HDFS (or other Hadoop-supported
    * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
    * use `SparkFiles.get(fileName)` to find its download location.
@@ -1506,6 +1508,8 @@ class SparkContext(config: SparkConf) extends Logging {
   /**
    * Add a file to be downloaded with this Spark job on every node.
    *
+   * If a file is added during execution, it will not be available until the next TaskSet starts.
+   *
    * @param path can be either a local file, a file in HDFS (or other Hadoop-supported
    * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
    * use `SparkFiles.get(fileName)` to find its download location.
@@ -1792,6 +1796,9 @@ class SparkContext(config: SparkConf) extends Logging {
   /**
    * Adds a JAR dependency for all tasks to be executed on this `SparkContext` in the future.
+   *
+   * If a jar is added during execution, it will not be available until the next TaskSet starts.
+   *
    * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems),
    * an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
    */
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 589fe672ad..c2510714e1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -56,6 +56,10 @@ private[spark] class TaskSetManager(
 
   private val conf = sched.sc.conf
 
+  // SPARK-21563 make a copy of the jars/files so they are consistent across the TaskSet
+  private val addedJars = HashMap[String, Long](sched.sc.addedJars.toSeq: _*)
+  private val addedFiles = HashMap[String, Long](sched.sc.addedFiles.toSeq: _*)
+
   // Quantile of tasks at which to start speculation
   val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75)
   val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5)
@@ -502,8 +506,8 @@ private[spark] class TaskSetManager(
         execId,
         taskName,
         index,
-        sched.sc.addedFiles,
-        sched.sc.addedJars,
+        addedFiles,
+        addedJars,
         task.localProperties,
         serializedTask)
     }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index e46900e4e5..3696df06e0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.internal.config
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{AccumulatorV2, ManualClock}
+import org.apache.spark.util.{AccumulatorV2, ManualClock, Utils}
 
 class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
   extends DAGScheduler(sc) {
@@ -1214,6 +1214,38 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
 
     verify(taskSetManagerSpy, times(1)).addPendingTask(anyInt())
   }
 
+  test("SPARK-21563 context's added jars shouldn't change mid-TaskSet") {
+    sc = new SparkContext("local", "test")
+    val addedJarsPreTaskSet = Map[String, Long](sc.addedJars.toSeq: _*)
+    assert(addedJarsPreTaskSet.size === 0)
+
+    sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+    val taskSet1 = FakeTask.createTaskSet(3)
+    val manager1 = new TaskSetManager(sched, taskSet1, MAX_TASK_FAILURES, clock = new ManualClock)
+
+    // all tasks from the first taskset have the same jars
+    val taskOption1 = manager1.resourceOffer("exec1", "host1", NO_PREF)
+    assert(taskOption1.get.addedJars === addedJarsPreTaskSet)
+    val taskOption2 = manager1.resourceOffer("exec1", "host1", NO_PREF)
+    assert(taskOption2.get.addedJars === addedJarsPreTaskSet)
+
+    // even with a jar added mid-TaskSet
+    val jarPath = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar")
+    sc.addJar(jarPath.toString)
+    val addedJarsMidTaskSet = Map[String, Long](sc.addedJars.toSeq: _*)
+    assert(addedJarsPreTaskSet !== addedJarsMidTaskSet)
+    val taskOption3 = manager1.resourceOffer("exec1", "host1", NO_PREF)
+    // which should have the old version of the jars list
+    assert(taskOption3.get.addedJars === addedJarsPreTaskSet)
+
+    // and then the jar does appear in the next TaskSet
+    val taskSet2 = FakeTask.createTaskSet(1)
+    val manager2 = new TaskSetManager(sched, taskSet2, MAX_TASK_FAILURES, clock = new ManualClock)
+
+    val taskOption4 = manager2.resourceOffer("exec1", "host1", NO_PREF)
+    assert(taskOption4.get.addedJars === addedJarsMidTaskSet)
+  }
+
   private def createTaskResult(
       id: Int,
       accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = {
--
GitLab