Commit 21f333c6 authored by Takeshi Yamamuro, committed by Wenchen Fan

[SPARK-19751][SQL] Throw an exception if bean class has one's own class in fields

## What changes were proposed in this pull request?
The current master throws `StackOverflowError` in `createDataFrame`/`createDataset` if a bean has a field of its own class:
```
public class SelfClassInFieldBean implements Serializable {
  private SelfClassInFieldBean child;
  ...
}
```
This PR adds code to throw an `UnsupportedOperationException` in that case as early as possible.
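
For illustration, a minimal sketch of the failure mode end to end, assuming a running `SparkSession` named `spark` (the accessor bodies below are filled in by us for illustration; the snippet above elides them):

```java
import java.io.Serializable;
import java.util.Arrays;

public class SelfClassInFieldBean implements Serializable {
  private SelfClassInFieldBean child;

  // Illustrative accessors; the original snippet shows only the field.
  public SelfClassInFieldBean getChild() { return child; }
  public void setChild(SelfClassInFieldBean child) { this.child = child; }
}

// Schema inference must recurse into `child`, whose type is the bean itself,
// so before this patch it looped until StackOverflowError. With the patch,
//   spark.createDataFrame(Arrays.asList(new SelfClassInFieldBean()),
//       SelfClassInFieldBean.class);
// fails fast with UnsupportedOperationException instead.
```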

## How was this patch tested?
Added tests in `JavaDataFrameSuite` and `JavaDatasetSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #17188 from maropu/SPARK-19751.
parent fc931467
JavaTypeInference.scala
@@ -69,7 +69,8 @@ object JavaTypeInference {
    * @param typeToken Java type
    * @return (SQL data type, nullable)
    */
-  private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
+  private def inferDataType(typeToken: TypeToken[_], seenTypeSet: Set[Class[_]] = Set.empty)
+    : (DataType, Boolean) = {
     typeToken.getRawType match {
       case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
         (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
@@ -104,26 +105,32 @@ object JavaTypeInference {
       case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)

       case _ if typeToken.isArray =>
-        val (dataType, nullable) = inferDataType(typeToken.getComponentType)
+        val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
         (ArrayType(dataType, nullable), true)

       case _ if iterableType.isAssignableFrom(typeToken) =>
-        val (dataType, nullable) = inferDataType(elementType(typeToken))
+        val (dataType, nullable) = inferDataType(elementType(typeToken), seenTypeSet)
         (ArrayType(dataType, nullable), true)

       case _ if mapType.isAssignableFrom(typeToken) =>
         val (keyType, valueType) = mapKeyValueType(typeToken)
-        val (keyDataType, _) = inferDataType(keyType)
-        val (valueDataType, nullable) = inferDataType(valueType)
+        val (keyDataType, _) = inferDataType(keyType, seenTypeSet)
+        val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet)
         (MapType(keyDataType, valueDataType, nullable), true)

       case other =>
+        if (seenTypeSet.contains(other)) {
+          throw new UnsupportedOperationException(
+            "Cannot have circular references in bean class, but got the circular reference " +
+              s"of class $other")
+        }
+
         // TODO: we should only collect properties that have getter and setter. However, some tests
         // pass in scala case class as java bean class which doesn't have getter and setter.
         val properties = getJavaBeanReadableProperties(other)
         val fields = properties.map { property =>
           val returnType = typeToken.method(property.getReadMethod).getReturnType
-          val (dataType, nullable) = inferDataType(returnType)
+          val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other)
           new StructField(property.getName, dataType, nullable)
         }
         (new StructType(fields), true)
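
The fix threads an immutable `seenTypeSet` of bean classes down the recursion and fails as soon as a class reappears on the current path; arrays pass the set through unchanged, and struct fields recurse with `seenTypeSet + other`. A standalone Java sketch of the same pattern (our illustration, not Spark's code; it also omits the collection/map element-type resolution that Spark does via Guava's `TypeToken`):

```java
import java.beans.IntrospectionException;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.util.HashSet;
import java.util.Set;

public final class BeanCycleCheck {
  // Depth-first walk over readable bean properties. `seen` holds only the
  // classes on the current path, so sibling fields of the same type are legal.
  static void check(Class<?> cls, Set<Class<?>> seen) throws IntrospectionException {
    if (cls.isPrimitive() || cls.getName().startsWith("java.")) {
      return; // leaf and JDK types cannot start a bean-level cycle here
    }
    if (cls.isArray()) { // like the array case above: same seen set, element type
      check(cls.getComponentType(), seen);
      return;
    }
    if (seen.contains(cls)) {
      throw new UnsupportedOperationException(
          "Cannot have circular references in bean class, but got the circular reference "
              + "of class " + cls);
    }
    Set<Class<?>> next = new HashSet<>(seen); // immutable-style: copy, then add
    next.add(cls);
    for (PropertyDescriptor p : Introspector.getBeanInfo(cls, Object.class)
        .getPropertyDescriptors()) {
      if (p.getReadMethod() != null) {
        check(p.getPropertyType(), next);
      }
    }
  }
}
```

For example, `BeanCycleCheck.check(SelfClassInFieldBean.class, new HashSet<>())` throws after one level of recursion, mirroring the behavior the tests below pin down.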
JavaDataFrameSuite.java
@@ -423,4 +423,36 @@ public class JavaDataFrameSuite {
     Assert.assertEquals(1L, df.count());
     Assert.assertEquals(2L, df.collectAsList().get(0).getLong(0));
   }
+
+  public class CircularReference1Bean implements Serializable {
+    private CircularReference2Bean child;
+
+    public CircularReference2Bean getChild() {
+      return child;
+    }
+
+    public void setChild(CircularReference2Bean child) {
+      this.child = child;
+    }
+  }
+
+  public class CircularReference2Bean implements Serializable {
+    private CircularReference1Bean child;
+
+    public CircularReference1Bean getChild() {
+      return child;
+    }
+
+    public void setChild(CircularReference1Bean child) {
+      this.child = child;
+    }
+  }
+
+  // Check a simple DataFrame case here; the exhaustive tests for the
+  // circular-reference issue live in `JavaDatasetSuite`.
+  @Test(expected = UnsupportedOperationException.class)
+  public void testCircularReferenceBean() {
+    CircularReference1Bean bean = new CircularReference1Bean();
+    spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class);
+  }
 }
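
Note the mutual reference here is caught even though neither bean contains itself directly: inferring `CircularReference1Bean` puts it in the seen set, recursing into `CircularReference2Bean` adds that too, and the second bean's `child` field then hits `CircularReference1Bean` already on the path. If one wanted to pin down the message text as well, a hypothetical companion test inside the same suite (our illustration, not part of the patch) could look like:

```java
// Hypothetical test, not in the patch; assumes the suite's `spark` session
// and the bean classes defined above.
@Test
public void testCircularReferenceBeanMessage() {
  try {
    spark.createDataFrame(
        Arrays.asList(new CircularReference1Bean()), CircularReference1Bean.class);
    Assert.fail("Expected UnsupportedOperationException");
  } catch (UnsupportedOperationException e) {
    // Message text comes from the patched inferDataType above.
    Assert.assertTrue(e.getMessage().contains(
        "Cannot have circular references in bean class"));
  }
}
```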
JavaDatasetSuite.java
@@ -1291,4 +1291,91 @@ public class JavaDatasetSuite implements Serializable {
     Assert.assertEquals(df.schema().length(), 0);
     Assert.assertEquals(df.collectAsList().size(), 1);
   }
+
+  public class CircularReference1Bean implements Serializable {
+    private CircularReference2Bean child;
+
+    public CircularReference2Bean getChild() {
+      return child;
+    }
+
+    public void setChild(CircularReference2Bean child) {
+      this.child = child;
+    }
+  }
+
+  public class CircularReference2Bean implements Serializable {
+    private CircularReference1Bean child;
+
+    public CircularReference1Bean getChild() {
+      return child;
+    }
+
+    public void setChild(CircularReference1Bean child) {
+      this.child = child;
+    }
+  }
+
+  public class CircularReference3Bean implements Serializable {
+    private CircularReference3Bean[] child;
+
+    public CircularReference3Bean[] getChild() {
+      return child;
+    }
+
+    public void setChild(CircularReference3Bean[] child) {
+      this.child = child;
+    }
+  }
+
+  public class CircularReference4Bean implements Serializable {
+    private Map<String, CircularReference5Bean> child;
+
+    public Map<String, CircularReference5Bean> getChild() {
+      return child;
+    }
+
+    public void setChild(Map<String, CircularReference5Bean> child) {
+      this.child = child;
+    }
+  }
+
+  public class CircularReference5Bean implements Serializable {
+    private String id;
+    private List<CircularReference4Bean> child;
+
+    public String getId() {
+      return id;
+    }
+
+    public List<CircularReference4Bean> getChild() {
+      return child;
+    }
+
+    public void setId(String id) {
+      this.id = id;
+    }
+
+    public void setChild(List<CircularReference4Bean> child) {
+      this.child = child;
+    }
+  }
+
+  @Test(expected = UnsupportedOperationException.class)
+  public void testCircularReferenceBean1() {
+    CircularReference1Bean bean = new CircularReference1Bean();
+    spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference1Bean.class));
+  }
+
+  @Test(expected = UnsupportedOperationException.class)
+  public void testCircularReferenceBean2() {
+    CircularReference3Bean bean = new CircularReference3Bean();
+    spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference3Bean.class));
+  }
+
+  @Test(expected = UnsupportedOperationException.class)
+  public void testCircularReferenceBean3() {
+    CircularReference4Bean bean = new CircularReference4Bean();
+    spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference4Bean.class));
+  }
 }
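
These three tests cover direct/mutual references, a cycle through an array type, and a cycle through `Map`/`List` element types. One design point worth noting: because `seenTypeSet` tracks only the classes on the current recursion path, a bean type that appears several times in a schema without containing itself is still accepted. A hedged sketch with invented bean names (not part of the patch):

```java
import java.io.Serializable;

// LeafBean is referenced twice by NestedBean, but no inference path ever
// revisits a class, so this is a DAG, not a cycle, and no exception is thrown.
public class NestedBeanExample implements Serializable {
  public static class LeafBean implements Serializable {
    private String name;

    public String getName() { return name; }
    public void setName(String name) { this.name = name; }
  }

  public static class NestedBean implements Serializable {
    private LeafBean left;
    private LeafBean right;

    public LeafBean getLeft() { return left; }
    public void setLeft(LeafBean left) { this.left = left; }
    public LeafBean getRight() { return right; }
    public void setRight(LeafBean right) { this.right = right; }
  }
}

// spark.createDataset(Arrays.asList(new NestedBean()), Encoders.bean(NestedBean.class))
// succeeds: LeafBean is visited once under `left` and once under `right`,
// each time with a seen set containing only NestedBean.
```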