Commit f6480b14 authored by gmoehler, committed by gatorsmile

[SPARK-19311][SQL] fix UDT hierarchy issue

## What changes were proposed in this pull request?
acceptsType() in UserDefinedType now accepts not only a UDT of the exact same class but also any UDT whose user class is a subclass of its own user class. As a result, a value of a derived-type UDT is accepted wherever the base-type UDT is expected.
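
To illustrate the rule, here is a minimal standalone sketch (not part of the patch; `Shape`, `Circle`, and the `accepts` helper are hypothetical stand-ins that only mirror the `isAssignableFrom` check introduced here):

```scala
// Standalone sketch of the new acceptance rule (not part of the patch).
// `Shape` / `Circle` are hypothetical stand-ins for two UDTs' user classes.
trait Shape
class Circle extends Shape

object AcceptsTypeSketch {
  // Simplified stand-in for UserDefinedType#acceptsType: the expected UDT
  // accepts the actual UDT when the expected user class is the same as, or a
  // supertype of, the actual user class.
  def accepts(expectedUserClass: Class[_], actualUserClass: Class[_]): Boolean =
    expectedUserClass == actualUserClass ||
      expectedUserClass.isAssignableFrom(actualUserClass)

  def main(args: Array[String]): Unit = {
    assert(accepts(classOf[Shape], classOf[Circle]))  // base accepts derived (new behavior)
    assert(!accepts(classOf[Circle], classOf[Shape])) // derived still rejects base
  }
}
```

This is the same relationship the new test below exercises through `doOtherUDF(doSubTypeUDF(42))`.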

## How was this patch tested?
Manual test using a set of generated UDTs, fixing acceptsType() in my user-defined types. A regression test exercising the UDT hierarchy is also added to UserDefinedTypeSuite.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: gmoehler <moehler@de.ibm.com>

Closes #16660 from gmoehler/master.
parent f1ddca5f
@@ -78,8 +78,12 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
    */
   override private[spark] def asNullable: UserDefinedType[UserType] = this
 
-  override private[sql] def acceptsType(dataType: DataType) =
-    this.getClass == dataType.getClass
+  override private[sql] def acceptsType(dataType: DataType) = dataType match {
+    case other: UserDefinedType[_] =>
+      this.getClass == other.getClass ||
+        this.userClass.isAssignableFrom(other.userClass)
+    case _ => false
+  }
 
   override def sql: String = sqlType.sql
@@ -20,7 +20,8 @@ package org.apache.spark.sql
 import scala.beans.{BeanInfo, BeanProperty}
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
 import org.apache.spark.sql.functions._
@@ -71,6 +72,77 @@ object UDT {
 }
 
+// object and classes to test SPARK-19311
+
+// Trait/Interface for base type
+sealed trait IExampleBaseType extends Serializable {
+  def field: Int
+}
+
+// Trait/Interface for derived type
+sealed trait IExampleSubType extends IExampleBaseType
+
+// a base class
+class ExampleBaseClass(override val field: Int) extends IExampleBaseType
+
+// a derived class
+class ExampleSubClass(override val field: Int)
+  extends ExampleBaseClass(field) with IExampleSubType
+
+// UDT for base class
+class ExampleBaseTypeUDT extends UserDefinedType[IExampleBaseType] {
+
+  override def sqlType: StructType = {
+    StructType(Seq(
+      StructField("intfield", IntegerType, nullable = false)))
+  }
+
+  override def serialize(obj: IExampleBaseType): InternalRow = {
+    val row = new GenericInternalRow(1)
+    row.setInt(0, obj.field)
+    row
+  }
+
+  override def deserialize(datum: Any): IExampleBaseType = {
+    datum match {
+      case row: InternalRow =>
+        require(row.numFields == 1,
+          "ExampleBaseTypeUDT requires row with length == 1")
+        val field = row.getInt(0)
+        new ExampleBaseClass(field)
+    }
+  }
+
+  override def userClass: Class[IExampleBaseType] = classOf[IExampleBaseType]
+}
+
+// UDT for derived class
+private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] {
+
+  override def sqlType: StructType = {
+    StructType(Seq(
+      StructField("intfield", IntegerType, nullable = false)))
+  }
+
+  override def serialize(obj: IExampleSubType): InternalRow = {
+    val row = new GenericInternalRow(1)
+    row.setInt(0, obj.field)
+    row
+  }
+
+  override def deserialize(datum: Any): IExampleSubType = {
+    datum match {
+      case row: InternalRow =>
+        require(row.numFields == 1,
+          "ExampleSubTypeUDT requires row with length == 1")
+        val field = row.getInt(0)
+        new ExampleSubClass(field)
+    }
+  }
+
+  override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
+}
+
 class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
   import testImplicits._
@@ -194,4 +266,35 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
     // call `collect` to make sure this query can pass analysis.
     pointsRDD.as[MyLabeledPoint].map(_.copy(label = 2.0)).collect()
   }
+
+  test("SPARK-19311: UDFs disregard UDT type hierarchy") {
+    UDTRegistration.register(classOf[IExampleBaseType].getName,
+      classOf[ExampleBaseTypeUDT].getName)
+    UDTRegistration.register(classOf[IExampleSubType].getName,
+      classOf[ExampleSubTypeUDT].getName)
+
+    // UDF that returns a base class object
+    sqlContext.udf.register("doUDF", (param: Int) => {
+      new ExampleBaseClass(param)
+    }: IExampleBaseType)
+
+    // UDF that returns a derived class object
+    sqlContext.udf.register("doSubTypeUDF", (param: Int) => {
+      new ExampleSubClass(param)
+    }: IExampleSubType)
+
+    // UDF that takes a base class object as parameter
+    sqlContext.udf.register("doOtherUDF", (obj: IExampleBaseType) => {
+      obj.field
+    }: Int)
+
+    // this worked already before the fix SPARK-19311:
+    // return type of doUDF equals parameter type of doOtherUDF
+    sql("SELECT doOtherUDF(doUDF(41))")
+
+    // this one passes only with the fix SPARK-19311:
+    // return type of doSubTypeUDF is a subtype of the parameter type of doOtherUDF
+    sql("SELECT doOtherUDF(doSubTypeUDF(42))")
+  }
 }