Skip to content
Snippets Groups Projects
Commit 127a6678 authored by Sela's avatar Sela Committed by Michael Armbrust
Browse files

[SPARK-15489][SQL] Dataset kryo encoder won't load custom user settings

## What changes were proposed in this pull request?

Serializer instantiation will consider existing SparkConf

## How was this patch tested?
manual test with `ImmutableList` (Guava) and `kryo-serializers`'s `Immutable*Serializer` implementations.

Added Test Suite.

(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)

Author: Sela <ansela@paypal.com>

Closes #13424 from amitsela/SPARK-15489.
parent aec502d9
No related branches found
No related tags found
No related merge requests found
...@@ -22,7 +22,7 @@ import java.lang.reflect.Modifier ...@@ -22,7 +22,7 @@ import java.lang.reflect.Modifier
import scala.language.existentials import scala.language.existentials
import scala.reflect.ClassTag import scala.reflect.ClassTag
import org.apache.spark.SparkConf import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer._ import org.apache.spark.serializer._
import org.apache.spark.sql.Row import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.InternalRow
...@@ -547,11 +547,17 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) ...@@ -547,11 +547,17 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
(classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
} }
} }
// try conf from env, otherwise create a new one
val env = s"${classOf[SparkEnv].getName}.get()"
val sparkConf = s"new ${classOf[SparkConf].getName}()" val sparkConf = s"new ${classOf[SparkConf].getName}()"
ctx.addMutableState( val serializerInit = s"""
serializerInstanceClass, if ($env == null) {
serializer, $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") } else {
$serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
}
"""
ctx.addMutableState(serializerInstanceClass, serializer, serializerInit)
// Code to serialize. // Code to serialize.
val input = child.genCode(ctx) val input = child.genCode(ctx)
...@@ -587,11 +593,17 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B ...@@ -587,11 +593,17 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
(classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
} }
} }
// try conf from env, otherwise create a new one
val env = s"${classOf[SparkEnv].getName}.get()"
val sparkConf = s"new ${classOf[SparkConf].getName}()" val sparkConf = s"new ${classOf[SparkConf].getName}()"
ctx.addMutableState( val serializerInit = s"""
serializerInstanceClass, if ($env == null) {
serializer, $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") } else {
$serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
}
"""
ctx.addMutableState(serializerInstanceClass, serializer, serializerInit)
// Code to deserialize. // Code to deserialize.
val input = child.genCode(ctx) val input = child.genCode(ctx)
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.serializer.KryoRegistrator
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.TestSparkSession
/**
* Test suite to test Kryo custom registrators.
*/
class DatasetSerializerRegistratorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
/**
* Initialize the [[TestSparkSession]] with a [[KryoRegistrator]].
*/
protected override def beforeAll(): Unit = {
sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName)
super.beforeAll()
}
test("Kryo registrator") {
implicit val kryoEncoder = Encoders.kryo[KryoData]
val ds = Seq(KryoData(1), KryoData(2)).toDS()
assert(ds.collect().toSet == Set(KryoData(0), KryoData(0)))
}
}
/** Used to test user provided registrator. */
class TestRegistrator extends KryoRegistrator {
override def registerClasses(kryo: Kryo): Unit =
kryo.register(classOf[KryoData], new ZeroKryoDataSerializer())
}
object TestRegistrator {
def apply(): TestRegistrator = new TestRegistrator()
}
/** A [[Serializer]] that takes a [[KryoData]] and serializes it as KryoData(0). */
class ZeroKryoDataSerializer extends Serializer[KryoData] {
override def write(kryo: Kryo, output: Output, t: KryoData): Unit = {
output.writeInt(0)
}
override def read(kryo: Kryo, input: Input, aClass: Class[KryoData]): KryoData = {
KryoData(input.readInt())
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment