From f6480b1467d0432fb2aa48c7a3a8a6e6679fd481 Mon Sep 17 00:00:00 2001
From: gmoehler <moehler@de.ibm.com>
Date: Wed, 25 Jan 2017 08:17:24 -0800
Subject: [PATCH] [SPARK-19311][SQL] fix UDT hierarchy issue

## What changes were proposed in this pull request?

acceptsType() in UDT now accepts not only the same UDT type but also any UDT
whose user class is a subclass of its own user class.

## How was this patch tested?

Manual test using a set of generated UDTs with the fixed acceptsType(), plus a
new unit test in UserDefinedTypeSuite.

Author: gmoehler <moehler@de.ibm.com>

Closes #16660 from gmoehler/master.
---
 .../spark/sql/types/UserDefinedType.scala     |   8 +-
 .../spark/sql/UserDefinedTypeSuite.scala      | 105 +++++++++++++++++-
 2 files changed, 110 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index c33219c95b..5a944e763e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -78,8 +78,12 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
    */
   override private[spark] def asNullable: UserDefinedType[UserType] = this
 
-  override private[sql] def acceptsType(dataType: DataType) =
-    this.getClass == dataType.getClass
+  override private[sql] def acceptsType(dataType: DataType) = dataType match {
+    case other: UserDefinedType[_] =>
+      this.getClass == other.getClass ||
+        this.userClass.isAssignableFrom(other.userClass)
+    case _ => false
+  }
 
   override def sql: String = sqlType.sql
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 474f17ff7a..ea4a8ee7ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql
 import scala.beans.{BeanInfo, BeanProperty}
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
 import org.apache.spark.sql.functions._
@@ -71,6 +72,77 @@ object UDT {
 
 }
 
+// object and classes to test SPARK-19311
+
+// Trait/Interface for base type
+sealed trait IExampleBaseType extends Serializable {
+  def field: Int
+}
+
+// Trait/Interface for derived type
+sealed trait IExampleSubType extends IExampleBaseType
+
+// a base class
+class ExampleBaseClass(override val field: Int) extends IExampleBaseType
+
+// a derived class
+class ExampleSubClass(override val field: Int)
+  extends ExampleBaseClass(field) with IExampleSubType
+
+// UDT for base class
+class ExampleBaseTypeUDT extends UserDefinedType[IExampleBaseType] {
+
+  override def sqlType: StructType = {
+    StructType(Seq(
+      StructField("intfield", IntegerType, nullable = false)))
+  }
+
+  override def serialize(obj: IExampleBaseType): InternalRow = {
+    val row = new GenericInternalRow(1)
+    row.setInt(0, obj.field)
+    row
+  }
+
+  override def deserialize(datum: Any): IExampleBaseType = {
+    datum match {
+      case row: InternalRow =>
+        require(row.numFields == 1,
"ExampleBaseTypeUDT requires row with length == 1") + val field = row.getInt(0) + new ExampleBaseClass(field) + } + } + + override def userClass: Class[IExampleBaseType] = classOf[IExampleBaseType] +} + +// UDT for derived class +private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] { + + override def sqlType: StructType = { + StructType(Seq( + StructField("intfield", IntegerType, nullable = false))) + } + + override def serialize(obj: IExampleSubType): InternalRow = { + val row = new GenericInternalRow(1) + row.setInt(0, obj.field) + row + } + + override def deserialize(datum: Any): IExampleSubType = { + datum match { + case row: InternalRow => + require(row.numFields == 1, + "ExampleSubTypeUDT requires row with length == 1") + val field = row.getInt(0) + new ExampleSubClass(field) + } + } + + override def userClass: Class[IExampleSubType] = classOf[IExampleSubType] +} + class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { import testImplicits._ @@ -194,4 +266,35 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT // call `collect` to make sure this query can pass analysis. pointsRDD.as[MyLabeledPoint].map(_.copy(label = 2.0)).collect() } + + test("SPARK-19311: UDFs disregard UDT type hierarchy") { + UDTRegistration.register(classOf[IExampleBaseType].getName, + classOf[ExampleBaseTypeUDT].getName) + UDTRegistration.register(classOf[IExampleSubType].getName, + classOf[ExampleSubTypeUDT].getName) + + // UDF that returns a base class object + sqlContext.udf.register("doUDF", (param: Int) => { + new ExampleBaseClass(param) + }: IExampleBaseType) + + // UDF that returns a derived class object + sqlContext.udf.register("doSubTypeUDF", (param: Int) => { + new ExampleSubClass(param) + }: IExampleSubType) + + // UDF that takes a base class object as parameter + sqlContext.udf.register("doOtherUDF", (obj: IExampleBaseType) => { + obj.field + }: Int) + + // this worked already before the fix SPARK-19311: + // return type of doUDF equals parameter type of doOtherUDF + sql("SELECT doOtherUDF(doUDF(41))") + + // this one passes only with the fix SPARK-19311: + // return type of doSubUDF is a subtype of the parameter type of doOtherUDF + sql("SELECT doOtherUDF(doSubTypeUDF(42))") + } + } -- GitLab