Skip to content
Snippets Groups Projects
Commit cc364877 authored by Reynold Xin
Browse files

[SPARK-3046] use executor's class loader as the default serializer classloader

The serializer is not always used in an executor thread (e.g. connection manager, broadcast), in which case the classloader might not have the user jar set, leading to corruption in deserialization.

https://issues.apache.org/jira/browse/SPARK-3046

https://issues.apache.org/jira/browse/SPARK-2878

Author: Reynold Xin <rxin@apache.org>

Closes #1972 from rxin/kryoBug and squashes the following commits:

c1c7bf0 [Reynold Xin] Made change to JavaSerializer.
7204c33 [Reynold Xin] Added imports back.
d879e67 [Reynold Xin] [SPARK-3046] use executor's class loader as the default serializer class loader.
parent c7032290
No related branches found
No related tags found
No related merge requests found
......@@ -99,6 +99,9 @@ private[spark] class Executor(
private val urlClassLoader = createClassLoader()
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
// Set the classloader for serializer
env.serializer.setDefaultClassLoader(urlClassLoader)
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
......
......@@ -63,7 +63,9 @@ extends DeserializationStream {
/** Closes this stream by closing `objIn`. */
def close() { objIn.close() }
}
private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader)
extends SerializerInstance {
def serialize[T: ClassTag](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
......@@ -109,7 +111,10 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize
class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)
def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset)
/**
 * Creates a [[JavaSerializerInstance]] that uses the configured default class loader, or the
 * calling thread's context class loader when no default has been set.
 */
override def newInstance(): SerializerInstance = {
  val loader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
  new JavaSerializerInstance(counterReset, loader)
}
override def writeExternal(out: ObjectOutput) {
out.writeInt(counterReset)
......
......@@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf)
val instantiator = new EmptyScalaKryoInstantiator
val kryo = instantiator.newKryo()
kryo.setRegistrationRequired(registrationRequired)
val classLoader = Thread.currentThread.getContextClassLoader
val oldClassLoader = Thread.currentThread.getContextClassLoader
val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
// Do this before we invoke the user registrator so the user registrator can override this.
......@@ -84,10 +86,15 @@ class KryoSerializer(conf: SparkConf)
try {
val reg = Class.forName(regCls, true, classLoader).newInstance()
.asInstanceOf[KryoRegistrator]
// Use the default classloader when calling the user registrator.
Thread.currentThread.setContextClassLoader(classLoader)
reg.registerClasses(kryo)
} catch {
case e: Exception =>
throw new SparkException(s"Failed to invoke $regCls", e)
} finally {
Thread.currentThread.setContextClassLoader(oldClassLoader)
}
}
......
......@@ -44,6 +44,23 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
*/
@DeveloperApi
trait Serializer {

  /**
   * Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should
   * make sure it is using this when set.
   */
  @volatile protected var defaultClassLoader: Option[ClassLoader] = None

  /**
   * Sets a class loader for the serializer to use in deserialization.
   *
   * A `null` argument is stored as `None` instead of `Some(null)`, so implementations that
   * read the default with `getOrElse(...)` fall back to their own loader rather than
   * receiving a null class loader.
   *
   * @param classLoader the loader to use for deserialization; `null` clears the default
   * @return this Serializer object
   */
  def setDefaultClassLoader(classLoader: ClassLoader): Serializer = {
    // Option(...) maps null to None; Some(null) would defeat every getOrElse fallback.
    defaultClassLoader = Option(classLoader)
    this
  }

  /** Creates a new [[SerializerInstance]]. */
  def newInstance(): SerializerInstance
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.serializer
import org.apache.spark.util.Utils
import com.esotericsoftware.kryo.Kryo
import org.scalatest.FunSuite
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils}
import org.apache.spark.SparkContext._
import org.apache.spark.serializer.KryoDistributedTest._
/**
 * Runs a shuffle-and-join job on a `local-cluster` master with Kryo enabled and a registrator
 * whose custom class is only available through a generated application jar.
 */
class KryoSerializerDistributedSuite extends FunSuite {

  test("kryo objects are serialised consistently in different processes") {
    val conf = new SparkConf(false)
    conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName)
    conf.set("spark.task.maxFailures", "1")

    // Package the custom class into a jar and ship it with the application.
    val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName))
    conf.setJars(List(jar.getPath))

    val sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
    // Always stop the context, even when the job or the assertion fails, so a failure here
    // does not leave a running SparkContext behind.
    try {
      // Give the driver-side serializer a loader that can see the generated jar.
      val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
      SparkEnv.get.serializer.setDefaultClassLoader(loader)

      val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 3).cache()

      // Randomly mix the keys so that the join below will require a shuffle with each partition
      // sending data to multiple other partitions.
      val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, o)}

      // Join the two RDDs, and force evaluation. For i in 0 until 10 the remapped key
      // i * i * (i - 10) is negative except for i == 0, so exactly one key matches.
      assert(shuffledRDD.join(cachedRDD).collect().size == 1)
    } finally {
      LocalSparkContext.stop(sc)
    }
  }
}
/** Fixtures used by [[KryoSerializerDistributedSuite]]. */
object KryoDistributedTest {

  /** Empty class; the suite stores instances of it as RDD values. */
  class MyCustomClass

  /**
   * Registrator that registers the class named by `AppJarRegistrator.customClassName`,
   * resolving it through the calling thread's context class loader.
   */
  class AppJarRegistrator extends KryoRegistrator {
    override def registerClasses(k: Kryo): Unit = {
      val contextLoader = Thread.currentThread.getContextClassLoader
      val customClass = Class.forName(AppJarRegistrator.customClassName, true, contextLoader)
      k.register(customClass)
    }
  }

  object AppJarRegistrator {
    /** Name of the class that the suite packages into the test jar. */
    val customClassName = "KryoSerializerDistributedSuiteCustomClass"
  }
}
......@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
import com.esotericsoftware.kryo.Kryo
import org.scalatest.FunSuite
import org.apache.spark.SharedSparkContext
import org.apache.spark.{SparkConf, SharedSparkContext}
import org.apache.spark.serializer.KryoTest._
class KryoSerializerSuite extends FunSuite with SharedSparkContext {
......@@ -217,8 +217,29 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance())
assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist"))
}
test("default class loader can be set by a different thread") {
  val kryoSerializer = new KryoSerializer(new SparkConf)

  // Round-trip once with no default class loader configured; this must succeed.
  val instance = kryoSerializer.newInstance()
  val payload = instance.serialize(new ClassLoaderTestingObject)
  instance.deserialize[ClassLoaderTestingObject](payload)

  // Install a loader that rejects every lookup. If newly created instances consult the
  // default class loader, deserialization has to surface this exception.
  val failingLoader: ClassLoader = new ClassLoader() {
    override def loadClass(name: String) = throw new UnsupportedOperationException
  }
  kryoSerializer.setDefaultClassLoader(failingLoader)

  intercept[UnsupportedOperationException] {
    kryoSerializer.newInstance().deserialize[ClassLoaderTestingObject](payload)
  }
}
}
/**
 * Empty class that the "default class loader" test in KryoSerializerSuite serializes and
 * deserializes.
 */
class ClassLoaderTestingObject
class KryoSerializerResizableOutputSuite extends FunSuite {
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment