diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 33f6ce080c339e2d6cd63ae5604bc8ceba9fde95..3ecc137c8cd7f4e28a168933c4f79b7c6e90f22e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.Utils /** * Type-inference utilities for POJOs and Java collections. @@ -120,8 +119,7 @@ object JavaTypeInference { (MapType(keyDataType, valueDataType, nullable), true) case other if other.isEnum => - (StructType(Seq(StructField(typeToken.getRawType.getSimpleName, - StringType, nullable = false))), true) + (StringType, true) case other => if (seenTypeSet.contains(other)) { @@ -310,9 +308,12 @@ object JavaTypeInference { returnNullable = false) case other if other.isEnum => - StaticInvoke(JavaTypeInference.getClass, ObjectType(other), "deserializeEnumName", - expressions.Literal.create(other.getEnumConstants.apply(0), ObjectType(other)) - :: getPath :: Nil) + StaticInvoke( + other, + ObjectType(other), + "valueOf", + Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil, + returnNullable = false) case other => val properties = getJavaBeanReadableAndWritableProperties(other) @@ -356,30 +357,6 @@ object JavaTypeInference { } } - /** Returns a mapping from enum value to int for given enum type */ - def enumSerializer[T <: Enum[T]](enum: Class[T]): T => UTF8String = { - assert(enum.isEnum) - inputObject: T => - UTF8String.fromString(inputObject.name()) - } - - /** Returns value index for given enum type and value */ - def serializeEnumName[T <: Enum[T]](enum: UTF8String, inputObject: T): UTF8String = { - enumSerializer(Utils.classForName(enum.toString).asInstanceOf[Class[T]])(inputObject) - } - - /** Returns a mapping from int to enum value for given enum type */ - def enumDeserializer[T <: Enum[T]](enum: Class[T]): InternalRow => T = { - assert(enum.isEnum) - value: InternalRow => - Enum.valueOf(enum, value.getUTF8String(0).toString) - } - - /** Returns enum value for given enum type and value index */ - def deserializeEnumName[T <: Enum[T]](typeDummy: T, inputObject: InternalRow): T = { - enumDeserializer(typeDummy.getClass.asInstanceOf[Class[T]])(inputObject) - } - private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { @@ -465,9 +442,12 @@ object JavaTypeInference { ) case other if other.isEnum => - CreateNamedStruct(expressions.Literal("enum") :: - StaticInvoke(JavaTypeInference.getClass, StringType, "serializeEnumName", - expressions.Literal.create(other.getName, StringType) :: inputObject :: Nil) :: Nil) + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false) :: Nil, + returnNullable = false) case other => val properties = getJavaBeanReadableAndWritableProperties(other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 9ed5e120344b137578be27c7b8d1591f9364197d..efc2882f0a3d39db7cbb17897538913fe9a280d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} -import org.apache.spark.sql.types.{BooleanType, DataType, ObjectType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType} import org.apache.spark.util.Utils /** @@ -81,19 +81,9 @@ object ExpressionEncoder { ClassTag[T](cls)) } - def javaEnumSchema[T](beanClass: Class[T]): DataType = { - StructType(Seq(StructField("enum", - StructType(Seq(StructField(beanClass.getSimpleName, StringType, nullable = false))), - nullable = false))) - } - // TODO: improve error message for java bean encoder. def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = { - val schema = if (beanClass.isEnum) { - javaEnumSchema(beanClass) - } else { - JavaTypeInference.inferDataType(beanClass)._1 - } + val schema = JavaTypeInference.inferDataType(beanClass)._1 assert(schema.isInstanceOf[StructType]) val serializer = JavaTypeInference.serializerFor(beanClass) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 7c466fe03cdcfefab386ab90b89cd0b3c1f54b89..9b28a18035b1c20fa26760b1b96edf03d9286d1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -154,13 +154,13 @@ case class StaticInvoke( val evaluate = if (returnNullable) { if (ctx.defaultValue(dataType) == "null") { s""" - ${ev.value} = (($javaType) ($callFunc)); + ${ev.value} = $callFunc; ${ev.isNull} = ${ev.value} == null; """ } else { val boxedResult = ctx.freshName("boxedResult") s""" - ${ctx.boxedType(dataType)} $boxedResult = (($javaType) ($callFunc)); + ${ctx.boxedType(dataType)} $boxedResult = $callFunc; ${ev.isNull} = $boxedResult == null; if (!${ev.isNull}) { ${ev.value} = $boxedResult; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index a34474683013f4384bbee09ebcef832b9172d5c7..3e57403bede9e9d60186995c7bc8c32ca34a4afb 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1283,13 +1283,13 @@ public class JavaDatasetSuite implements Serializable { ds.collectAsList(); } - public enum EnumBean { + public enum MyEnum { A("www.elgoog.com"), B("www.google.com"); private String url; - EnumBean(String url) { + MyEnum(String url) { this.url = url; } @@ -1302,16 +1302,8 @@ public class JavaDatasetSuite implements Serializable { } } - @Test - public void testEnum() { - List<EnumBean> data = Arrays.asList(EnumBean.B); - Encoder<EnumBean> encoder = Encoders.bean(EnumBean.class); - Dataset<EnumBean> ds = spark.createDataset(data, encoder); - Assert.assertEquals(ds.collectAsList(), data); - } - public static class BeanWithEnum { - EnumBean enumField; + MyEnum enumField; String regularField; public String getRegularField() { @@ -1322,15 +1314,15 @@ public class JavaDatasetSuite implements Serializable { this.regularField = regularField; } - public EnumBean getEnumField() { + public MyEnum getEnumField() { return enumField; } - public void setEnumField(EnumBean field) { + public void setEnumField(MyEnum field) { this.enumField = field; } - public BeanWithEnum(EnumBean enumField, String regularField) { + public BeanWithEnum(MyEnum enumField, String regularField) { this.enumField = enumField; this.regularField = regularField; } @@ -1353,8 +1345,8 @@ public class JavaDatasetSuite implements Serializable { @Test public void testBeanWithEnum() { - List<BeanWithEnum> data = Arrays.asList(new BeanWithEnum(EnumBean.A, "mira avenue"), - new BeanWithEnum(EnumBean.B, "flower boulevard")); + List<BeanWithEnum> data = Arrays.asList(new BeanWithEnum(MyEnum.A, "mira avenue"), + new BeanWithEnum(MyEnum.B, "flower boulevard")); Encoder<BeanWithEnum> encoder = Encoders.bean(BeanWithEnum.class); Dataset<BeanWithEnum> ds = spark.createDataset(data, encoder); Assert.assertEquals(ds.collectAsList(), data);