From 3ae25f244bd471ef77002c703f2cc7ed6b524f11 Mon Sep 17 00:00:00 2001
From: Joan <joan@goyeau.com>
Date: Tue, 19 Apr 2016 17:36:31 -0700
Subject: [PATCH] [SPARK-13929] Use Scala reflection for UDTs

## What changes were proposed in this pull request?

Enable ScalaReflection and User Defined Types for plain Scala classes.

This involves moving `schemaFor` from the `ScalaReflection` trait (which covers both runtime and compile-time (macro) reflection) to the `ScalaReflection` object (runtime reflection only), as I believe this code would not work at compile time anyway: it manipulates `Class`es that are not yet compiled.

## How was this patch tested?

Unit tests.

Author: Joan <joan@goyeau.com>

Closes #12149 from joan38/SPARK-13929-Scala-reflection.
---
 .../spark/sql/types/SQLUserDefinedType.java  |   5 -
 .../spark/sql/catalyst/ScalaReflection.scala |  98 ++++++--------
 .../sql/catalyst/ScalaReflectionSuite.scala  |  37 +++++-
 .../spark/sql/UserDefinedTypeSuite.scala     | 123 +++++++++++-------
 .../datasources/json/JsonSuite.scala         |   4 +-
 .../execution/AggregationQuerySuite.scala    |   2 +-
 .../sql/sources/hadoopFsRelationSuites.scala |   2 +-
 7 files changed, 157 insertions(+), 114 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java
index 1e4e5ede8c..110ed460cc 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java
@@ -24,11 +24,6 @@ import org.apache.spark.annotation.DeveloperApi;
 /**
  * ::DeveloperApi::
  * A user-defined type which can be automatically recognized by a SQLContext and registered.
- * <p>
- * WARNING: This annotation will only work if both Java and Scala reflection return the same class
- * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class
- * is enclosed in an object (a singleton).
- * <p>
  * WARNING: UDTs are currently only supported from Scala.
  */
 // TODO: Should I used @Documented ?
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 4795fc2557..bd723135b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -374,10 +374,8 @@ object ScalaReflection extends ScalaReflection {
           newInstance
         }
 
-      case t if Utils.classIsLoadable(className) &&
-        Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
-        val udt = Utils.classForName(className)
-          .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+      case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
+        val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
         val obj = NewInstance(
           udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
           Nil,
@@ -432,7 +430,6 @@ object ScalaReflection extends ScalaReflection {
     if (!inputObject.dataType.isInstanceOf[ObjectType]) {
       inputObject
     } else {
-      val className = getClassNameFromType(tpe)
       tpe match {
         case t if t <:< localTypeOf[Option[_]] =>
           val TypeRef(_, _, Seq(optType)) = t
@@ -589,9 +586,8 @@ object ScalaReflection extends ScalaReflection {
         case t if t <:< localTypeOf[java.lang.Boolean] =>
           Invoke(inputObject, "booleanValue", BooleanType)
 
-        case t if Utils.classIsLoadable(className) &&
-          Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
-          val udt = Utils.classForName(className)
+        case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
+          val udt = getClassFromType(t)
             .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
           val obj = NewInstance(
             udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
@@ -637,24 +633,6 @@ object ScalaReflection extends ScalaReflection {
    * Retrieves the runtime class corresponding to the provided type.
    */
   def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
-}
-
-/**
- * Support for generating catalyst schemas for scala objects. Note that unlike its companion
- * object, this trait able to work in both the runtime and the compile time (macro) universe.
- */
-trait ScalaReflection {
-  /** The universe we work in (runtime or macro) */
-  val universe: scala.reflect.api.Universe
-
-  /** The mirror used to access types in the universe */
-  def mirror: universe.Mirror
-
-  import universe._
-
-  // The Predef.Map is scala.collection.immutable.Map.
-  // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
-  import scala.collection.Map
 
   case class Schema(dataType: DataType, nullable: Boolean)
 
@@ -668,36 +646,22 @@ trait ScalaReflection {
   def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T])
 
   /**
-   * Return the Scala Type for `T` in the current classloader mirror.
-   *
-   * Use this method instead of the convenience method `universe.typeOf`, which
-   * assumes that all types can be found in the classloader that loaded scala-reflect classes.
-   * That's not necessarily the case when running using Eclipse launchers or even
-   * Sbt console or test (without `fork := true`).
+   * Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
    *
-   * @see SPARK-5281
+   * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return
+   * `NullType` silently instead.
    */
-  // SPARK-13640: Synchronize this because TypeTag.tpe is not thread-safe in Scala 2.10.
-  def localTypeOf[T: TypeTag]: `Type` = ScalaReflectionLock.synchronized {
-    val tag = implicitly[TypeTag[T]]
-    tag.in(mirror).tpe.normalize
+  def silentSchemaFor(tpe: `Type`): Schema = try {
+    schemaFor(tpe)
+  } catch {
+    case _: UnsupportedOperationException => Schema(NullType, nullable = true)
   }
 
   /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
   def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
-    val className = getClassNameFromType(tpe)
-
     tpe match {
-
-      case t if Utils.classIsLoadable(className) &&
-        Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
-
-        // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection,
-        // whereas className is from Scala reflection. This can make it hard to find classes
-        // in some cases, such as when a class is enclosed in an object (in which case
-        // Java appends a '$' to the object name but Scala does not).
-        val udt = Utils.classForName(className)
-          .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+      case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
+        val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
         Schema(udt, nullable = true)
       case t if t <:< localTypeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
@@ -748,17 +712,39 @@ trait ScalaReflection {
         throw new UnsupportedOperationException(s"Schema for type $other is not supported")
     }
   }
+}
+
+/**
+ * Support for generating catalyst schemas for scala objects. Note that unlike its companion
+ * object, this trait able to work in both the runtime and the compile time (macro) universe.
+ */
+trait ScalaReflection {
+  /** The universe we work in (runtime or macro) */
+  val universe: scala.reflect.api.Universe
+
+  /** The mirror used to access types in the universe */
+  def mirror: universe.Mirror
+
+  import universe._
+
+  // The Predef.Map is scala.collection.immutable.Map.
+  // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
+  import scala.collection.Map
 
   /**
-   * Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
+   * Return the Scala Type for `T` in the current classloader mirror.
    *
-   * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return
-   * `NullType` silently instead.
+   * Use this method instead of the convenience method `universe.typeOf`, which
+   * assumes that all types can be found in the classloader that loaded scala-reflect classes.
+   * That's not necessarily the case when running using Eclipse launchers or even
+   * Sbt console or test (without `fork := true`).
+   *
+   * @see SPARK-5281
    */
-  def silentSchemaFor(tpe: `Type`): Schema = try {
-    schemaFor(tpe)
-  } catch {
-    case _: UnsupportedOperationException => Schema(NullType, nullable = true)
+  // SPARK-13640: Synchronize this because TypeTag.tpe is not thread-safe in Scala 2.10.
+  def localTypeOf[T: TypeTag]: `Type` = ScalaReflectionLock.synchronized {
+    val tag = implicitly[TypeTag[T]]
+    tag.in(mirror).tpe.normalize
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 5ca5a72512..0672551b29 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -23,7 +23,7 @@ import java.sql.{Date, Timestamp}
 import scala.reflect.runtime.universe.typeOf
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.BoundReference
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificMutableRow}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -81,9 +81,44 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) {
   def this(b: String, a: Int) = this(a, b, c = 1.0)
 }
 
+object TestingUDT {
+  @SQLUserDefinedType(udt = classOf[NestedStructUDT])
+  class NestedStruct(val a: Integer, val b: Long, val c: Double)
+
+  class NestedStructUDT extends UserDefinedType[NestedStruct] {
+    override def sqlType: DataType = new StructType()
+      .add("a", IntegerType, nullable = true)
+      .add("b", LongType, nullable = false)
+      .add("c", DoubleType, nullable = false)
+
+    override def serialize(n: NestedStruct): Any = {
+      val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType))
+      row.setInt(0, n.a)
+      row.setLong(1, n.b)
+      row.setDouble(2, n.c)
+    }
+
+    override def userClass: Class[NestedStruct] = classOf[NestedStruct]
+
+    override def deserialize(datum: Any): NestedStruct = datum match {
+      case row: InternalRow =>
+        new NestedStruct(row.getInt(0), row.getLong(1), row.getDouble(2))
+    }
+  }
+}
+
+
 class ScalaReflectionSuite extends SparkFunSuite {
   import org.apache.spark.sql.catalyst.ScalaReflection._
 
+  test("SQLUserDefinedType annotation on Scala structure") {
+    val schema = schemaFor[TestingUDT.NestedStruct]
+    assert(schema === Schema(
+      new TestingUDT.NestedStructUDT,
+      nullable = true
+    ))
+  }
+
   test("primitive data") {
     val schema = schemaFor[PrimitiveData]
     assert(schema === Schema(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 8c4afb605b..acc9f48d7e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -27,51 +27,56 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
-@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
-private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
-  override def equals(other: Any): Boolean = other match {
-    case v: MyDenseVector =>
-      java.util.Arrays.equals(this.data, v.data)
-    case _ => false
-  }
-}
-
 @BeanInfo
 private[sql] case class MyLabeledPoint(
-    @BeanProperty label: Double,
-    @BeanProperty features: MyDenseVector)
+  @BeanProperty label: Double,
+  @BeanProperty features: UDT.MyDenseVector)
+
+// Wrapped in an object to check Scala compatibility. See SPARK-13929
+object UDT {
+
+  @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
+  private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
+    override def equals(other: Any): Boolean = other match {
+      case v: MyDenseVector =>
+        java.util.Arrays.equals(this.data, v.data)
+      case _ => false
+    }
+  }
 
-private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
+  private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
 
-  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
+    override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
 
-  override def serialize(features: MyDenseVector): ArrayData = {
-    new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
-  }
+    override def serialize(features: MyDenseVector): ArrayData = {
+      new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
+    }
 
-  override def deserialize(datum: Any): MyDenseVector = {
-    datum match {
-      case data: ArrayData =>
-        new MyDenseVector(data.toDoubleArray())
+    override def deserialize(datum: Any): MyDenseVector = {
+      datum match {
+        case data: ArrayData =>
+          new MyDenseVector(data.toDoubleArray())
+      }
     }
-  }
 
-  override def userClass: Class[MyDenseVector] = classOf[MyDenseVector]
+    override def userClass: Class[MyDenseVector] = classOf[MyDenseVector]
 
-  private[spark] override def asNullable: MyDenseVectorUDT = this
+    private[spark] override def asNullable: MyDenseVectorUDT = this
 
-  override def equals(other: Any): Boolean = other match {
-    case _: MyDenseVectorUDT => true
-    case _ => false
+    override def equals(other: Any): Boolean = other match {
+      case _: MyDenseVectorUDT => true
+      case _ => false
+    }
   }
+
 }
 
 class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
   import testImplicits._
 
   private lazy val pointsRDD = Seq(
-    MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
-    MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))).toDF()
+    MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
+    MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))).toDF()
 
   test("register user type: MyDenseVector for MyLabeledPoint") {
     val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v }
@@ -80,16 +85,16 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
     assert(labelsArrays.contains(1.0))
     assert(labelsArrays.contains(0.0))
 
-    val features: RDD[MyDenseVector] =
-      pointsRDD.select('features).rdd.map { case Row(v: MyDenseVector) => v }
-    val featuresArrays: Array[MyDenseVector] = features.collect()
+    val features: RDD[UDT.MyDenseVector] =
+      pointsRDD.select('features).rdd.map { case Row(v: UDT.MyDenseVector) => v }
+    val featuresArrays: Array[UDT.MyDenseVector] = features.collect()
     assert(featuresArrays.size === 2)
-    assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0))))
-    assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0))))
+    assert(featuresArrays.contains(new UDT.MyDenseVector(Array(0.1, 1.0))))
+    assert(featuresArrays.contains(new UDT.MyDenseVector(Array(0.2, 2.0))))
   }
 
   test("UDTs and UDFs") {
-    sqlContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
+    sqlContext.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector])
     pointsRDD.registerTempTable("points")
     checkAnswer(
       sql("SELECT testType(features) from points"),
@@ -103,8 +108,8 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
       checkAnswer(
        sqlContext.read.parquet(path),
        Seq(
-          Row(1.0, new MyDenseVector(Array(0.1, 1.0))),
-          Row(0.0, new MyDenseVector(Array(0.2, 2.0)))))
+          Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
+          Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
     }
   }
 
@@ -115,18 +120,19 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
      checkAnswer(
        sqlContext.read.parquet(path),
        Seq(
-          Row(1.0, new MyDenseVector(Array(0.1, 1.0))),
-          Row(0.0, new MyDenseVector(Array(0.2, 2.0)))))
+          Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
+          Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
     }
   }
 
   // Tests to make sure that all operators correctly convert types on the way out.
   test("Local UDTs") {
-    val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec")
-    df.collect()(0).getAs[MyDenseVector](1)
-    df.take(1)(0).getAs[MyDenseVector](1)
-    df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
-    df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
+    val df = Seq((1, new UDT.MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec")
+    df.collect()(0).getAs[UDT.MyDenseVector](1)
+    df.take(1)(0).getAs[UDT.MyDenseVector](1)
+    df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[UDT.MyDenseVector](0)
+    df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0)
+      .getAs[UDT.MyDenseVector](0)
   }
 
   test("UDTs with JSON") {
@@ -136,26 +142,47 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
    )
    val schema = StructType(Seq(
      StructField("id", IntegerType, false),
-      StructField("vec", new MyDenseVectorUDT, false)
+      StructField("vec", new UDT.MyDenseVectorUDT, false)
    ))
 
    val stringRDD = sparkContext.parallelize(data)
    val jsonRDD = sqlContext.read.schema(schema).json(stringRDD)
    checkAnswer(
      jsonRDD,
-      Row(1, new MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) ::
-        Row(2, new MyDenseVector(Array(2.25, 4.5, 8.75))) ::
+      Row(1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) ::
+        Row(2, new UDT.MyDenseVector(Array(2.25, 4.5, 8.75))) ::
        Nil
    )
  }
 
+  test("UDTs with JSON and Dataset") {
+    val data = Seq(
+      "{\"id\":1,\"vec\":[1.1,2.2,3.3,4.4]}",
+      "{\"id\":2,\"vec\":[2.25,4.5,8.75]}"
+    )
+
+    val schema = StructType(Seq(
+      StructField("id", IntegerType, false),
+      StructField("vec", new UDT.MyDenseVectorUDT, false)
+    ))
+
+    val stringRDD = sparkContext.parallelize(data)
+    val jsonDataset = sqlContext.read.schema(schema).json(stringRDD)
+      .as[(Int, UDT.MyDenseVector)]
+    checkDataset(
+      jsonDataset,
+      (1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))),
+      (2, new UDT.MyDenseVector(Array(2.25, 4.5, 8.75)))
+    )
+  }
+
  test("SPARK-10472 UserDefinedType.typeName") {
    assert(IntegerType.typeName === "integer")
-    assert(new MyDenseVectorUDT().typeName === "mydensevector")
+    assert(new UDT.MyDenseVectorUDT().typeName === "mydensevector")
  }
 
  test("Catalyst type converter null handling for UDTs") {
-    val udt = new MyDenseVectorUDT()
+    val udt = new UDT.MyDenseVectorUDT()
    val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt)
    assert(toScalaConverter(null) === null)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index e17340c70b..1a7b62ca0a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -1421,7 +1421,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
      FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), DateType, TimestampType,
      ArrayType(IntegerType), MapType(StringType, LongType), struct,
-      new MyDenseVectorUDT())
+      new UDT.MyDenseVectorUDT())
    val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
      StructField(s"col$index", dataType, nullable = true)
    }
@@ -1445,7 +1445,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
      Seq(2, 3, 4),
      Map("a string" -> 2000L),
      Row(4.75.toFloat, Seq(false, true)),
-      new MyDenseVector(Array(0.25, 2.25, 4.25)))
+      new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))
 
    val data =
      Row.fromSeq(Seq("Spark " + sqlContext.sparkContext.version) ++ constantValues) :: Nil
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 84bb7edf03..bc87d3ef38 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -868,7 +868,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
      FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), DateType, TimestampType,
      ArrayType(IntegerType), MapType(StringType, LongType), struct,
-      new MyDenseVectorUDT())
+      new UDT.MyDenseVectorUDT())
    // Right now, we will use SortBasedAggregate to handle UDAFs.
    // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use
    // UnsafeRow as the aggregation buffer. While, dataTypes will trigger
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 3d02256792..368fe62ff2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -113,7 +113,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
    new StructType()
      .add("f1", FloatType, nullable = true)
      .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
-    new MyDenseVectorUDT()
+    new UDT.MyDenseVectorUDT()
  ).filter(supportsDataType)
 
  try {
--
GitLab
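
Editorial note, not part of the patch: the description above replaces a Java-reflection lookup by class name with a check on the Scala type's own annotations. The following is a minimal sketch of the two styles of check, assuming Scala 2.10/2.11 runtime reflection and Spark's SQLUserDefinedType annotation; the object name and both method names are hypothetical, not Spark APIs.

// Sketch only: mirrors the direction of the patch, not Spark's exact internals.
import scala.reflect.runtime.universe._

import org.apache.spark.sql.types.SQLUserDefinedType

object UdtDetectionSketch {
  // New style: ask the Scala type symbol for its annotations directly, so no
  // Scala-to-Java class-name translation is involved. This is why the patch can
  // drop the "enclosed in an object" warning from SQLUserDefinedType.java.
  def hasUdtAnnotation(tpe: Type): Boolean =
    tpe.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType])

  // Old style (roughly what the patch removes): derive a class name from the Scala
  // type and look it up with Java reflection. Per the removed comment, Java appends
  // a '$' to the enclosing object's name while Scala does not, so a UDT class nested
  // in an object could fail to resolve here.
  def hasUdtAnnotationViaJavaReflection(className: String): Boolean =
    try {
      Class.forName(className).isAnnotationPresent(classOf[SQLUserDefinedType])
    } catch {
      case _: ClassNotFoundException => false
    }
}

In the patch itself the runtime class is then obtained with getClassFromType(t) (backed by mirror.runtimeClass) rather than by name, which keeps the Scala and Java views of the class consistent.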
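
A second sketch, also not part of the patch, of what the change enables at the call site. It reuses the TestingUDT classes added to ScalaReflectionSuite in the diff above, so it would only compile alongside that test code; outside the suite you would declare your own annotated class, and the SchemaForSketch wrapper is hypothetical.

// Sketch only: exercises ScalaReflection.schemaFor the same way the new test does.
import org.apache.spark.sql.catalyst.{ScalaReflection, TestingUDT}

object SchemaForSketch {
  def main(args: Array[String]): Unit = {
    // NestedStruct is annotated with @SQLUserDefinedType and nested inside an object,
    // which is exactly the case the old class-name-based lookup could not resolve.
    val schema = ScalaReflection.schemaFor[TestingUDT.NestedStruct]
    assert(schema.dataType.isInstanceOf[TestingUDT.NestedStructUDT])
    assert(schema.nullable)
  }
}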