From 21f333c635465069b7657d788052d510ffb0779a Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro <yamamuro@apache.org>
Date: Thu, 16 Mar 2017 08:50:01 +0800
Subject: [PATCH] [SPARK-19751][SQL] Throw an exception if bean class has one's
 own class in fields

## What changes were proposed in this pull request?
The current master throws `StackOverflowError` in `createDataFrame`/`createDataset` if bean has one's own class in fields;
```
public class SelfClassInFieldBean implements Serializable {
  private SelfClassInFieldBean child;
  ...
}
```
This pr added code to throw `UnsupportedOperationException` in that case as soon as possible.

## How was this patch tested?
Added tests in `JavaDataFrameSuite` and `JavaDatasetSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #17188 from maropu/SPARK-19751.
---
 .../sql/catalyst/JavaTypeInference.scala      | 19 ++--
 .../apache/spark/sql/JavaDataFrameSuite.java  | 32 +++++++
 .../apache/spark/sql/JavaDatasetSuite.java    | 87 +++++++++++++++++++
 3 files changed, 132 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index e9d9508e5a..4ff87edde1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -69,7 +69,8 @@ object JavaTypeInference {
    * @param typeToken Java type
    * @return (SQL data type, nullable)
    */
-  private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
+  private def inferDataType(typeToken: TypeToken[_], seenTypeSet: Set[Class[_]] = Set.empty)
+    : (DataType, Boolean) = {
     typeToken.getRawType match {
       case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
         (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
@@ -104,26 +105,32 @@ object JavaTypeInference {
       case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
 
       case _ if typeToken.isArray =>
-        val (dataType, nullable) = inferDataType(typeToken.getComponentType)
+        val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
         (ArrayType(dataType, nullable), true)
 
       case _ if iterableType.isAssignableFrom(typeToken) =>
-        val (dataType, nullable) = inferDataType(elementType(typeToken))
+        val (dataType, nullable) = inferDataType(elementType(typeToken), seenTypeSet)
         (ArrayType(dataType, nullable), true)
 
       case _ if mapType.isAssignableFrom(typeToken) =>
         val (keyType, valueType) = mapKeyValueType(typeToken)
-        val (keyDataType, _) = inferDataType(keyType)
-        val (valueDataType, nullable) = inferDataType(valueType)
+        val (keyDataType, _) = inferDataType(keyType, seenTypeSet)
+        val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet)
         (MapType(keyDataType, valueDataType, nullable), true)
 
       case other =>
+        if (seenTypeSet.contains(other)) {
+          throw new UnsupportedOperationException(
+            "Cannot have circular references in bean class, but got the circular reference " +
+              s"of class $other")
+        }
+
         // TODO: we should only collect properties that have getter and setter. However, some tests
         // pass in scala case class as java bean class which doesn't have getter and setter.
         val properties = getJavaBeanReadableProperties(other)
         val fields = properties.map { property =>
           val returnType = typeToken.method(property.getReadMethod).getReturnType
-          val (dataType, nullable) = inferDataType(returnType)
+          val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other)
           new StructField(property.getName, dataType, nullable)
         }
         (new StructType(fields), true)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index be8d95d0d9..b007093dad 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -423,4 +423,36 @@ public class JavaDataFrameSuite {
     Assert.assertEquals(1L, df.count());
     Assert.assertEquals(2L, df.collectAsList().get(0).getLong(0));
   }
+
+  public class CircularReference1Bean implements Serializable {
+    private CircularReference2Bean child;
+
+    public CircularReference2Bean getChild() {
+      return child;
+    }
+
+    public void setChild(CircularReference2Bean child) {
+      this.child = child;
+    }
+  }
+
+  public class CircularReference2Bean implements Serializable {
+    private CircularReference1Bean child;
+
+    public CircularReference1Bean getChild() {
+      return child;
+    }
+
+    public void setChild(CircularReference1Bean child) {
+      this.child = child;
+    }
+  }
+
+  // Checks a simple case for DataFrame here and put exhaustive tests for the issue
+  // of circular references in `JavaDatasetSuite`.
+  @Test(expected = UnsupportedOperationException.class)
+  public void testCircularReferenceBean() {
+    CircularReference1Bean bean = new CircularReference1Bean();
+    spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class);
+  }
 }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index d06e35bb44..439cac3dfb 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -1291,4 +1291,91 @@ public class JavaDatasetSuite implements Serializable {
     Assert.assertEquals(df.schema().length(), 0);
     Assert.assertEquals(df.collectAsList().size(), 1);
   }
+
+  public class CircularReference1Bean implements Serializable {
+    private CircularReference2Bean child;
+
+    public CircularReference2Bean getChild() {
+      return child;
+    }
+
+    public void setChild(CircularReference2Bean child) {
+      this.child = child;
+    }
+  }
+
+  public class CircularReference2Bean implements Serializable {
+    private CircularReference1Bean child;
+
+    public CircularReference1Bean getChild() {
+      return child;
+    }
+
+    public void setChild(CircularReference1Bean child) {
+      this.child = child;
+    }
+  }
+
+  public class CircularReference3Bean implements Serializable {
+    private CircularReference3Bean[] child;
+
+    public CircularReference3Bean[] getChild() {
+      return child;
+    }
+
+    public void setChild(CircularReference3Bean[] child) {
+      this.child = child;
+    }
+  }
+
+  public class CircularReference4Bean implements Serializable {
+    private Map<String, CircularReference5Bean> child;
+
+    public Map<String, CircularReference5Bean> getChild() {
+      return child;
+    }
+
+    public void setChild(Map<String, CircularReference5Bean> child) {
+      this.child = child;
+    }
+  }
+
+  public class CircularReference5Bean implements Serializable {
+    private String id;
+    private List<CircularReference4Bean> child;
+
+    public String getId() {
+      return id;
+    }
+
+    public List<CircularReference4Bean> getChild() {
+      return child;
+    }
+
+    public void setId(String id) {
+      this.id = id;
+    }
+
+    public void setChild(List<CircularReference4Bean> child) {
+      this.child = child;
+    }
+  }
+
+  @Test(expected = UnsupportedOperationException.class)
+  public void testCircularReferenceBean1() {
+    CircularReference1Bean bean = new CircularReference1Bean();
+    spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference1Bean.class));
+  }
+
+  @Test(expected = UnsupportedOperationException.class)
+  public void testCircularReferenceBean2() {
+    CircularReference3Bean bean = new CircularReference3Bean();
+    spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference3Bean.class));
+  }
+
+  @Test(expected = UnsupportedOperationException.class)
+  public void testCircularReferenceBean3() {
+    CircularReference4Bean bean = new CircularReference4Bean();
+    spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference4Bean.class));
+  }
 }
-- 
GitLab