Commit dfa61f7b authored by Shixiong Zhu, committed by Michael Armbrust

[SPARK-15190][SQL] Support using SQLUserDefinedType for case classes

## What changes were proposed in this pull request?

Right now, schema inference for case classes happens before the check for the `SQLUserDefinedType` annotation, so annotating a case class with `SQLUserDefinedType` has no effect.

This PR resolves it by simply reordering the checks so that the annotation is matched first. I also re-enabled the `java.math.BigDecimal` test and added two tests for `List`.
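
The root cause is pattern-match ordering: an annotated case class satisfies both the `definedByConstructorParams` check and the `SQLUserDefinedType` annotation check, so whichever `case` appears first in `ScalaReflection` wins. Below is a minimal, self-contained sketch of that dispatch, with hypothetical names standing in for the real reflective tests (not Spark's actual code):

```scala
sealed trait Encoding
case object CaseClassEncoding extends Encoding
case object UdtEncoding extends Encoding

// Hypothetical stand-in for the reflective checks in ScalaReflection.
case class TypeInfo(isCaseClass: Boolean, hasUdtAnnotation: Boolean)

// Before this patch: the case-class check came first, so an annotated
// case class never reached the UDT branch.
def encodingBefore(t: TypeInfo): Encoding = t match {
  case TypeInfo(true, _) => CaseClassEncoding
  case TypeInfo(_, true) => UdtEncoding
  case _ => throw new UnsupportedOperationException("no encoder")
}

// After this patch: the annotation check comes first, so the UDT wins.
def encodingAfter(t: TypeInfo): Encoding = t match {
  case TypeInfo(_, true) => UdtEncoding
  case TypeInfo(true, _) => CaseClassEncoding
  case _ => throw new UnsupportedOperationException("no encoder")
}

// encodingBefore(TypeInfo(isCaseClass = true, hasUdtAnnotation = true))
//   => CaseClassEncoding (the bug)
// encodingAfter(TypeInfo(isCaseClass = true, hasUdtAnnotation = true))
//   => UdtEncoding (the fix)
```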

## How was this patch tested?

`encodeDecodeTest(UDTCaseClass(new java.net.URI("http://spark.apache.org/")), "udt with case class")`
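
For context, a hedged end-to-end sketch of what the annotation enables once the ordering is fixed. It assumes an active `SparkSession` named `spark`; `UDTCaseClass` and `UDTForCaseClass` are the test fixtures added by this patch, and this snippet is illustrative rather than part of the test suite:

```scala
import spark.implicits._

// The derived encoder now routes through UDTForCaseClass: the URI is
// stored as the UDT's sqlType (StringType) and rebuilt as a
// java.net.URI on read.
val ds = Seq(UDTCaseClass(new java.net.URI("http://spark.apache.org/"))).toDS()
ds.collect().foreach(v => println(v.uri))
```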

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #12965 from zsxwing/SPARK-15190.
parent 22947cd0
@@ -348,6 +348,23 @@ object ScalaReflection extends ScalaReflection {
           "toScalaMap",
           keyData :: valueData :: Nil)
 
+      case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
+        val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+        val obj = NewInstance(
+          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+          Nil,
+          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
+      case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+        val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+          .asInstanceOf[UserDefinedType[_]]
+        val obj = NewInstance(
+          udt.getClass,
+          Nil,
+          dataType = ObjectType(udt.getClass))
+        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
       case t if definedByConstructorParams(t) =>
         val params = getConstructorParameters(t)
@@ -388,23 +405,6 @@ object ScalaReflection extends ScalaReflection {
         } else {
           newInstance
         }
-
-      case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
-        val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
-        val obj = NewInstance(
-          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
-          Nil,
-          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
-        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
-
-      case t if UDTRegistration.exists(getClassNameFromType(t)) =>
-        val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
-          .asInstanceOf[UserDefinedType[_]]
-        val obj = NewInstance(
-          udt.getClass,
-          Nil,
-          dataType = ObjectType(udt.getClass))
-        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
     }
   }
@@ -522,17 +522,6 @@ object ScalaReflection extends ScalaReflection {
         val TypeRef(_, _, Seq(elementType)) = t
         toCatalystArray(inputObject, elementType)
 
-      case t if definedByConstructorParams(t) =>
-        val params = getConstructorParameters(t)
-        val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
-          val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
-          val clsName = getClassNameFromType(fieldType)
-          val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-          expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
-        })
-        val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
-        expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
-
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
         toCatalystArray(inputObject, elementType)
@@ -645,6 +634,17 @@ object ScalaReflection extends ScalaReflection {
           dataType = ObjectType(udt.getClass))
         Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
 
+      case t if definedByConstructorParams(t) =>
+        val params = getConstructorParameters(t)
+        val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+          val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+          val clsName = getClassNameFromType(fieldType)
+          val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
+          expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
+        })
+        val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+        expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+
       case other =>
         throw new UnsupportedOperationException(
           s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
@@ -743,13 +743,6 @@ object ScalaReflection extends ScalaReflection {
       val Schema(valueDataType, valueNullable) = schemaFor(valueType)
       Schema(MapType(schemaFor(keyType).dataType,
         valueDataType, valueContainsNull = valueNullable), nullable = true)
-    case t if definedByConstructorParams(t) =>
-      val params = getConstructorParameters(t)
-      Schema(StructType(
-        params.map { case (fieldName, fieldType) =>
-          val Schema(dataType, nullable) = schemaFor(fieldType)
-          StructField(fieldName, dataType, nullable)
-        }), nullable = true)
     case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
     case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
     case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
@@ -775,6 +768,13 @@ object ScalaReflection extends ScalaReflection {
     case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
     case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
     case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+    case t if definedByConstructorParams(t) =>
+      val params = getConstructorParameters(t)
+      Schema(StructType(
+        params.map { case (fieldName, fieldType) =>
+          val Schema(dataType, nullable) = schemaFor(fieldType)
+          StructField(fieldName, dataType, nullable)
+        }), nullable = true)
     case other =>
       throw new UnsupportedOperationException(s"Schema for type $other is not supported")
   }
@@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 case class RepeatedStruct(s: Seq[PrimitiveData])
@@ -86,6 +87,25 @@ class JavaSerializable(val value: Int) extends Serializable {
   }
 }
 
+/** For testing UDT for a case class */
+@SQLUserDefinedType(udt = classOf[UDTForCaseClass])
+case class UDTCaseClass(uri: java.net.URI)
+
+class UDTForCaseClass extends UserDefinedType[UDTCaseClass] {
+  override def sqlType: DataType = StringType
+
+  override def serialize(obj: UDTCaseClass): UTF8String = {
+    UTF8String.fromString(obj.uri.toString)
+  }
+
+  override def userClass: Class[UDTCaseClass] = classOf[UDTCaseClass]
+
+  override def deserialize(datum: Any): UDTCaseClass = datum match {
+    case uri: UTF8String => UDTCaseClass(new java.net.URI(uri.toString))
+  }
+}
+
 class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   OuterScopes.addOuterScope(this)
@@ -147,6 +167,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple")
   encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple")
 
+  encodeDecodeTest(List(1, 2), "list of int")
+  encodeDecodeTest(List("a", null), "list with String and null")
+
+  encodeDecodeTest(
+    UDTCaseClass(new java.net.URI("http://spark.apache.org/")), "udt with case class")
+
   // Kryo encoders
   encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String]))
   encodeDecodeTest(new KryoSerializable(15), "kryo object")(