Commit 3cca5ffb authored by Hurshal Patel, committed by Yin Huai

[SPARK-11195][CORE] Use correct classloader for TaskResultGetter

Make sure we are using the context classloader when deserializing failed TaskResults instead of the Spark classloader.

The issue is that `enqueueFailedTask` was using the incorrect classloader, which results in a `ClassNotFoundException` when the failure involves a user-defined exception class.

Adds a test in TaskResultGetterSuite that compiles a custom exception, throws it on the executor, and asserts that Spark deserializes the TaskResult correctly instead of reporting `UnknownReason`.
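
For context, the two loaders differ in which classes they can see. A minimal
sketch of the distinction, assuming the usual context-classloader fallback
pattern (simplified for illustration, not the exact Spark source):

    // The Spark classloader sees only classes on Spark's own classpath.
    def getSparkClassLoader: ClassLoader = getClass.getClassLoader

    // The context classloader may also see user jars added at runtime
    // (e.g. through a MutableURLClassLoader); fall back to the Spark
    // loader when no context classloader is set.
    def getContextOrSparkClassLoader: ClassLoader =
      Option(Thread.currentThread().getContextClassLoader)
        .getOrElse(getSparkClassLoader)

A user-defined exception travels back to the driver inside the serialized
TaskEndReason, so deserializing it with the Spark classloader alone cannot
resolve the user's class; the context classloader can.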

See #9367 for previous comments
See SPARK-11195 for a full repro

Author: Hurshal Patel <hpatel516@gmail.com>

Closes #9779 from choochootrain/spark-11195-master.
parent 224723e6
core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark
 import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
 import java.net.{URI, URL}
 import java.nio.charset.StandardCharsets
+import java.nio.file.Paths
 import java.util.Arrays
 import java.util.jar.{JarEntry, JarOutputStream}
@@ -83,15 +84,15 @@ private[spark] object TestUtils {
   }

   /**
-   * Create a jar file that contains this set of files. All files will be located at the root
-   * of the jar.
+   * Create a jar file that contains this set of files. All files will be located in the specified
+   * directory or at the root of the jar.
    */
-  def createJar(files: Seq[File], jarFile: File): URL = {
+  def createJar(files: Seq[File], jarFile: File, directoryPrefix: Option[String] = None): URL = {
     val jarFileStream = new FileOutputStream(jarFile)
     val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest())

     for (file <- files) {
-      val jarEntry = new JarEntry(file.getName)
+      val jarEntry = new JarEntry(Paths.get(directoryPrefix.getOrElse(""), file.getName).toString)
       jarStream.putNextEntry(jarEntry)

       val in = new FileInputStream(file)
@@ -123,7 +124,7 @@ private[spark] object TestUtils {
       classpathUrls: Seq[URL]): File = {
     val compiler = ToolProvider.getSystemJavaCompiler

-    // Calling this outputs a class file in pwd. It's easier to just rename the file than
+    // Calling this outputs a class file in pwd. It's easier to just rename the files than
     // build a custom FileManager that controls the output location.
     val options = if (classpathUrls.nonEmpty) {
       Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator))
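As a usage note for the directoryPrefix parameter introduced above: entry
names are joined with Paths.get, so a class file lands at "<prefix>/<name>"
inside the jar. A hypothetical call (the file and directory names here are
invented for illustration):

    // With directoryPrefix = Some("repro"), MyException.class is written
    // as the jar entry "repro/MyException.class", i.e.
    // Paths.get("repro", "MyException.class").toString.
    val jar = new File(tempDir, "test.jar")
    TestUtils.createJar(Seq(compiledClassFile), jar, directoryPrefix = Some("repro"))

This is what lets the test below place repro.MyException at its
package-correct path inside the jar.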
core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -103,16 +103,16 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
     try {
       getTaskResultExecutor.execute(new Runnable {
         override def run(): Unit = Utils.logUncaughtExceptions {
+          val loader = Utils.getContextOrSparkClassLoader
           try {
             if (serializedData != null && serializedData.limit() > 0) {
               reason = serializer.get().deserialize[TaskEndReason](
-                serializedData, Utils.getSparkClassLoader)
+                serializedData, loader)
             }
           } catch {
             case cnd: ClassNotFoundException =>
               // Log an error but keep going here -- the task failed, so not catastrophic
               // if we can't deserialize the reason.
-              val loader = Utils.getContextOrSparkClassLoader
               logError(
                 "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
             case ex: Exception => {}
core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.scheduler

+import java.io.File
+import java.net.URL
 import java.nio.ByteBuffer

 import scala.concurrent.duration._
@@ -26,8 +28,10 @@ import scala.util.control.NonFatal
 import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.Eventually._

-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
+import org.apache.spark._
 import org.apache.spark.storage.TaskResultBlockId
+import org.apache.spark.TestUtils.JavaSourceFromString
+import org.apache.spark.util.{MutableURLClassLoader, Utils}

 /**
  * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
@@ -119,5 +123,64 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

     // Make sure two tasks were run (one failed one, and a second retried one).
     assert(scheduler.nextTaskId.get() === 2)
   }
+
+  /**
+   * Make sure we are using the context classloader when deserializing failed TaskResults instead
+   * of the Spark classloader.
+   *
+   * This test compiles a jar containing an exception and tests that when it is thrown on the
+   * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown
+   * exception as the cause.
+   *
+   * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing
+   * the exception, resulting in an UnknownReason for the TaskEndResult.
+   */
+  test("failed task deserialized with the correct classloader (SPARK-11195)") {
+    // compile a small jar containing an exception that will be thrown on an executor.
+    val tempDir = Utils.createTempDir()
+    val srcDir = new File(tempDir, "repro/")
+    srcDir.mkdirs()
+    val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath,
+      """package repro;
+        |
+        |public class MyException extends Exception {
+        |}
+      """.stripMargin)
+    val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty)
+    val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
+    TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro"))
+
+    // ensure we reset the classloader after the test completes
+    val originalClassLoader = Thread.currentThread.getContextClassLoader
+    try {
+      // load the exception from the jar
+      val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader)
+      loader.addURL(jarFile.toURI.toURL)
+      Thread.currentThread().setContextClassLoader(loader)
+      val excClass: Class[_] = Utils.classForName("repro.MyException")
+
+      // NOTE: we must run the cluster with "local" so that the executor can load the
+      // compiled jar.
+      sc = new SparkContext("local", "test", conf)
+      val rdd = sc.parallelize(Seq(1), 1).map { _ =>
+        val exc = excClass.newInstance().asInstanceOf[Exception]
+        throw exc
+      }
+
+      // the driver should not have any problems resolving the exception class and determining
+      // why the task failed.
+      val exceptionMessage = intercept[SparkException] {
+        rdd.collect()
+      }.getMessage
+
+      val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r
+      val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r
+
+      assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined)
+      assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty)
+    } finally {
+      Thread.currentThread.setContextClassLoader(originalClassLoader)
+    }
+  }
 }