From 3cca5ffb3d60d5de9a54bc71cf0b8279898936d2 Mon Sep 17 00:00:00 2001 From: Hurshal Patel <hpatel516@gmail.com> Date: Wed, 18 Nov 2015 09:28:59 -0800 Subject: [PATCH] [SPARK-11195][CORE] Use correct classloader for TaskResultGetter Make sure we are using the context classloader when deserializing failed TaskResults instead of the Spark classloader. The issue is that `enqueueFailedTask` was using the incorrect classloader which results in `ClassNotFoundException`. Adds a test in TaskResultGetterSuite that compiles a custom exception, throws it on the executor, and asserts that Spark handles the TaskResult deserialization instead of returning `UnknownReason`. See #9367 for previous comments See SPARK-11195 for a full repro Author: Hurshal Patel <hpatel516@gmail.com> Closes #9779 from choochootrain/spark-11195-master. --- .../scala/org/apache/spark/TestUtils.scala | 11 ++-- .../spark/scheduler/TaskResultGetter.scala | 4 +- .../scheduler/TaskResultGetterSuite.scala | 65 ++++++++++++++++++- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index acfe751f6c..43c89b258f 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} import java.nio.charset.StandardCharsets +import java.nio.file.Paths import java.util.Arrays import java.util.jar.{JarEntry, JarOutputStream} @@ -83,15 +84,15 @@ private[spark] object TestUtils { } /** - * Create a jar file that contains this set of files. All files will be located at the root - * of the jar. + * Create a jar file that contains this set of files. All files will be located in the specified + * directory or at the root of the jar. */ - def createJar(files: Seq[File], jarFile: File): URL = { + def createJar(files: Seq[File], jarFile: File, directoryPrefix: Option[String] = None): URL = { val jarFileStream = new FileOutputStream(jarFile) val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) for (file <- files) { - val jarEntry = new JarEntry(file.getName) + val jarEntry = new JarEntry(Paths.get(directoryPrefix.getOrElse(""), file.getName).toString) jarStream.putNextEntry(jarEntry) val in = new FileInputStream(file) @@ -123,7 +124,7 @@ private[spark] object TestUtils { classpathUrls: Seq[URL]): File = { val compiler = ToolProvider.getSystemJavaCompiler - // Calling this outputs a class file in pwd. It's easier to just rename the file than + // Calling this outputs a class file in pwd. It's easier to just rename the files than // build a custom FileManager that controls the output location. val options = if (classpathUrls.nonEmpty) { Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 46a6f6537e..f4965994d8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -103,16 +103,16 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul try { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { + val loader = Utils.getContextOrSparkClassLoader try { if (serializedData != null && serializedData.limit() > 0) { reason = serializer.get().deserialize[TaskEndReason]( - serializedData, Utils.getSparkClassLoader) + serializedData, loader) } } catch { case cnd: ClassNotFoundException => // Log an error but keep going here -- the task failed, so not catastrophic // if we can't deserialize the reason. - val loader = Utils.getContextOrSparkClassLoader logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) case ex: Exception => {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 815caa79ff..bc72c3685e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.io.File +import java.net.URL import java.nio.ByteBuffer import scala.concurrent.duration._ @@ -26,8 +28,10 @@ import scala.util.control.NonFatal import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.storage.TaskResultBlockId +import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. @@ -119,5 +123,64 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local // Make sure two tasks were run (one failed one, and a second retried one). assert(scheduler.nextTaskId.get() === 2) } + + /** + * Make sure we are using the context classloader when deserializing failed TaskResults instead + * of the Spark classloader. + + * This test compiles a jar containing an exception and tests that when it is thrown on the + * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown + * exception as the cause. + + * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing + * the exception, resulting in an UnknownReason for the TaskEndResult. + */ + test("failed task deserialized with the correct classloader (SPARK-11195)") { + // compile a small jar containing an exception that will be thrown on an executor. + val tempDir = Utils.createTempDir() + val srcDir = new File(tempDir, "repro/") + srcDir.mkdirs() + val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath, + """package repro; + | + |public class MyException extends Exception { + |} + """.stripMargin) + val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty) + val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) + TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro")) + + // ensure we reset the classloader after the test completes + val originalClassLoader = Thread.currentThread.getContextClassLoader + try { + // load the exception from the jar + val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) + loader.addURL(jarFile.toURI.toURL) + Thread.currentThread().setContextClassLoader(loader) + val excClass: Class[_] = Utils.classForName("repro.MyException") + + // NOTE: we must run the cluster with "local" so that the executor can load the compiled + // jar. + sc = new SparkContext("local", "test", conf) + val rdd = sc.parallelize(Seq(1), 1).map { _ => + val exc = excClass.newInstance().asInstanceOf[Exception] + throw exc + } + + // the driver should not have any problems resolving the exception class and determining + // why the task failed. + val exceptionMessage = intercept[SparkException] { + rdd.collect() + }.getMessage + + val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r + val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r + + assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined) + assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty) + } finally { + Thread.currentThread.setContextClassLoader(originalClassLoader) + } + } } -- GitLab