From 7d16776d28da5bcf656f0d8556b15ed3a5edca44 Mon Sep 17 00:00:00 2001
From: mike <mike0sv@gmail.com>
Date: Fri, 25 Aug 2017 07:22:34 +0100
Subject: [PATCH] [SPARK-21255][SQL][WIP] Fixed NPE when creating encoder for
 enum

## What changes were proposed in this pull request?

Fixed NPE when creating encoder for enum.

When you try to create an encoder for an enum type (or a bean with an enum property) via Encoders.bean(...), it fails with a NullPointerException at TypeToken:495.
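
A minimal, self-contained repro (the enum type here is hypothetical, not part of the patch):

```java
import org.apache.spark.sql.Encoders;

public class EnumNpeRepro {
  public enum Color { RED, GREEN }

  public static void main(String[] args) {
    // Before this patch: throws NullPointerException at TypeToken:495
    Encoders.bean(Color.class);
  }
}
```
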
I did a little research, and it turns out that in JavaTypeInference the following code
```scala
  def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
    val beanInfo = Introspector.getBeanInfo(beanClass)
    beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
      .filter(_.getReadMethod != null)
  }
```
filters out the property named "class", because we don't want to serialize it. But enum types have another property of type Class, named "declaringClass", which we then try to inspect recursively. That recursion eventually reaches the ClassLoader class, whose "defaultAssertionStatus" property has no read method, which leads to the NPE at TypeToken:495.
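
The chain is visible with plain java.beans introspection; a standalone sketch (not part of the patch):

```java
import java.beans.Introspector;
import java.beans.PropertyDescriptor;

public class IntrospectionChainDemo {
  public enum Color { RED, GREEN }

  public static void main(String[] args) throws Exception {
    // An enum exposes both "class" and "declaringClass" read-only properties.
    for (PropertyDescriptor pd : Introspector.getBeanInfo(Color.class).getPropertyDescriptors()) {
      System.out.println(pd.getName());
    }
    // ClassLoader has setDefaultAssertionStatus(boolean) but no getter, so the
    // "defaultAssertionStatus" property descriptor has a null read method.
    for (PropertyDescriptor pd :
        Introspector.getBeanInfo(ClassLoader.class).getPropertyDescriptors()) {
      System.out.println(pd.getName() + " -> read method: " + pd.getReadMethod());
    }
  }
}
```
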

I added the property name "declaringClass" to the filter to resolve this.

## How was this patch tested?
Unit tests in JavaDatasetSuite that create encoders for an enum type and for a bean with an enum field.
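
With the patch, an enum value is stored internally as a struct holding the constant's name as a string (see javaEnumSchema below), so the round trip looks like this; a sketch assuming a live SparkSession named spark, as in JavaDatasetSuite:

```java
// EnumBean is the test enum added in this patch; "spark" is assumed to be
// an active SparkSession, as in JavaDatasetSuite.
List<EnumBean> data = Arrays.asList(EnumBean.B);
Encoder<EnumBean> encoder = Encoders.bean(EnumBean.class);
Dataset<EnumBean> ds = spark.createDataset(data, encoder);
ds.collectAsList(); // [B] -- the original constant comes back
```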

Author: mike <mike0sv@gmail.com>
Author: Mikhail Sveshnikov <mike0sv@gmail.com>

Closes #18488 from mike0sv/enum-support.
---
 .../sql/catalyst/JavaTypeInference.scala      | 40 ++++++++++
 .../catalyst/encoders/ExpressionEncoder.scala | 14 +++-
 .../expressions/objects/objects.scala         |  4 +-
 .../apache/spark/sql/JavaDatasetSuite.java    | 77 +++++++++++++++++++
 4 files changed, 131 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 21363d3ba8..33f6ce080c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
 
 /**
  * Type-inference utilities for POJOs and Java collections.
@@ -118,6 +119,10 @@ object JavaTypeInference {
         val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet)
         (MapType(keyDataType, valueDataType, nullable), true)
 
+      case other if other.isEnum =>
+        (StructType(Seq(StructField(typeToken.getRawType.getSimpleName,
+          StringType, nullable = false))), true)
+
       case other =>
         if (seenTypeSet.contains(other)) {
           throw new UnsupportedOperationException(
@@ -140,6 +145,7 @@ object JavaTypeInference {
   def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
     val beanInfo = Introspector.getBeanInfo(beanClass)
     beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+      .filterNot(_.getName == "declaringClass")
       .filter(_.getReadMethod != null)
   }
 
@@ -303,6 +309,11 @@ object JavaTypeInference {
           keyData :: valueData :: Nil,
           returnNullable = false)
 
+      case other if other.isEnum =>
+        StaticInvoke(JavaTypeInference.getClass, ObjectType(other), "deserializeEnumName",
+          expressions.Literal.create(other.getEnumConstants.apply(0), ObjectType(other))
+            :: getPath :: Nil)
+
       case other =>
         val properties = getJavaBeanReadableAndWritableProperties(other)
         val setters = properties.map { p =>
@@ -345,6 +356,30 @@ object JavaTypeInference {
     }
   }
 
+  /** Returns a function that maps an enum value to the UTF8String of its name */
+  def enumSerializer[T <: Enum[T]](enum: Class[T]): T => UTF8String = {
+    assert(enum.isEnum)
+    inputObject: T =>
+      UTF8String.fromString(inputObject.name())
+  }
+
+  /** Returns the name of the given enum value as UTF8String; the enum class is passed by name */
+  def serializeEnumName[T <: Enum[T]](enum: UTF8String, inputObject: T): UTF8String = {
+    enumSerializer(Utils.classForName(enum.toString).asInstanceOf[Class[T]])(inputObject)
+  }
+
+  /** Returns a function that maps a single-string InternalRow to a value of the given enum type */
+  def enumDeserializer[T <: Enum[T]](enum: Class[T]): InternalRow => T = {
+    assert(enum.isEnum)
+    value: InternalRow =>
+      Enum.valueOf(enum, value.getUTF8String(0).toString)
+  }
+
+  /** Returns the enum value named in the given row; the enum type is taken from the dummy value */
+  def deserializeEnumName[T <: Enum[T]](typeDummy: T, inputObject: InternalRow): T = {
+    enumDeserializer(typeDummy.getClass.asInstanceOf[Class[T]])(inputObject)
+  }
+
   private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
 
     def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
@@ -429,6 +464,11 @@ object JavaTypeInference {
             valueNullable = true
           )
 
+        case other if other.isEnum =>
+          CreateNamedStruct(expressions.Literal("enum") ::
+          StaticInvoke(JavaTypeInference.getClass, StringType, "serializeEnumName",
+          expressions.Literal.create(other.getName, StringType) :: inputObject :: Nil) :: Nil)
+
         case other =>
           val properties = getJavaBeanReadableAndWritableProperties(other)
           val nonNullOutput = CreateNamedStruct(properties.flatMap { p =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index efc2882f0a..9ed5e12034 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
 import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
-import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
+import org.apache.spark.sql.types.{BooleanType, DataType, ObjectType, StringType, StructField, StructType}
 import org.apache.spark.util.Utils
 
 /**
@@ -81,9 +81,19 @@ object ExpressionEncoder {
       ClassTag[T](cls))
   }
 
+  def javaEnumSchema[T](beanClass: Class[T]): DataType = {
+    StructType(Seq(StructField("enum",
+      StructType(Seq(StructField(beanClass.getSimpleName, StringType, nullable = false))),
+      nullable = false)))
+  }
+
   // TODO: improve error message for java bean encoder.
   def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
-    val schema = JavaTypeInference.inferDataType(beanClass)._1
+    val schema = if (beanClass.isEnum) {
+      javaEnumSchema(beanClass)
+    } else {
+      JavaTypeInference.inferDataType(beanClass)._1
+    }
     assert(schema.isInstanceOf[StructType])
 
     val serializer = JavaTypeInference.serializerFor(beanClass)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 9b28a18035..7c466fe03c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -154,13 +154,13 @@ case class StaticInvoke(
     val evaluate = if (returnNullable) {
       if (ctx.defaultValue(dataType) == "null") {
         s"""
-          ${ev.value} = $callFunc;
+          ${ev.value} = (($javaType) ($callFunc));
           ${ev.isNull} = ${ev.value} == null;
         """
       } else {
         val boxedResult = ctx.freshName("boxedResult")
         s"""
-          ${ctx.boxedType(dataType)} $boxedResult = $callFunc;
+          ${ctx.boxedType(dataType)} $boxedResult = (($javaType) ($callFunc));
           ${ev.isNull} = $boxedResult == null;
           if (!${ev.isNull}) {
             ${ev.value} = $boxedResult;
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 4ca3b6406a..a344746830 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -1283,6 +1283,83 @@ public class JavaDatasetSuite implements Serializable {
     ds.collectAsList();
   }
 
+  public enum EnumBean {
+    A("www.elgoog.com"),
+    B("www.google.com");
+
+    private String url;
+
+    EnumBean(String url) {
+      this.url = url;
+    }
+
+    public String getUrl() {
+      return url;
+    }
+
+    public void setUrl(String url) {
+      this.url = url;
+    }
+  }
+
+  @Test
+  public void testEnum() {
+    List<EnumBean> data = Arrays.asList(EnumBean.B);
+    Encoder<EnumBean> encoder = Encoders.bean(EnumBean.class);
+    Dataset<EnumBean> ds = spark.createDataset(data, encoder);
+    Assert.assertEquals(ds.collectAsList(), data);
+  }
+
+  public static class BeanWithEnum {
+    EnumBean enumField;
+    String regularField;
+
+    public String getRegularField() {
+      return regularField;
+    }
+
+    public void setRegularField(String regularField) {
+      this.regularField = regularField;
+    }
+
+    public EnumBean getEnumField() {
+      return enumField;
+    }
+
+    public void setEnumField(EnumBean field) {
+      this.enumField = field;
+    }
+
+    public BeanWithEnum(EnumBean enumField, String regularField) {
+      this.enumField = enumField;
+      this.regularField = regularField;
+    }
+
+    public BeanWithEnum() {
+    }
+
+    public String toString() {
+      return "BeanWithEnum(" + enumField  + ", " + regularField + ")";
+    }
+
+    public boolean equals(Object other) {
+      if (other instanceof BeanWithEnum) {
+        BeanWithEnum beanWithEnum = (BeanWithEnum) other;
+        return beanWithEnum.regularField.equals(regularField) && beanWithEnum.enumField.equals(enumField);
+      }
+      return false;
+    }
+  }
+
+  @Test
+  public void testBeanWithEnum() {
+    List<BeanWithEnum> data = Arrays.asList(new BeanWithEnum(EnumBean.A, "mira avenue"),
+            new BeanWithEnum(EnumBean.B, "flower boulevard"));
+    Encoder<BeanWithEnum> encoder = Encoders.bean(BeanWithEnum.class);
+    Dataset<BeanWithEnum> ds = spark.createDataset(data, encoder);
+    Assert.assertEquals(ds.collectAsList(), data);
+  }
+
   public static class EmptyBean implements Serializable {}
 
   @Test
-- 
GitLab