Commit 37112fcf authored by hyukjinkwon, committed by Wenchen Fan

[SPARK-19666][SQL] Skip a property without getter in Java schema inference and allow empty bean in encoder creation

## What changes were proposed in this pull request?

This PR proposes two fixes.

**Skip a property without a getter in beans**

Currently, if we use a JavaBean without a getter as below:

```java
public static class BeanWithoutGetter implements Serializable {
  private String a;

  public void setA(String a) {
    this.a = a;
  }
}

BeanWithoutGetter bean = new BeanWithoutGetter();
List<BeanWithoutGetter> data = Arrays.asList(bean);
spark.createDataFrame(data, BeanWithoutGetter.class).show();
```

- Before

It throws an exception as below:

```
java.lang.NullPointerException
	at org.spark_project.guava.reflect.TypeToken.method(TypeToken.java:465)
	at org.apache.spark.sql.catalyst.JavaTypeInference$$anonfun$2.apply(JavaTypeInference.scala:126)
	at org.apache.spark.sql.catalyst.JavaTypeInference$$anonfun$2.apply(JavaTypeInference.scala:125)
```

- After

```
++
||
++
||
++
```
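
For reference, the skip boils down to keeping only bean properties that expose a read method, which is what the new `getJavaBeanReadableProperties` helper in the diff below does. Here is a minimal standalone sketch of that rule using only `java.beans`; the class and method names are illustrative and not part of the patch:

```java
import java.beans.IntrospectionException;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.io.Serializable;
import java.util.Arrays;

// Sketch only: illustrates the filtering rule applied during schema inference.
public class ReadablePropertiesDemo {
  public static class BeanWithoutGetter implements Serializable {
    private String a;
    public void setA(String a) { this.a = a; }
  }

  public static void main(String[] args) throws IntrospectionException {
    PropertyDescriptor[] readable = Arrays.stream(
        Introspector.getBeanInfo(BeanWithoutGetter.class).getPropertyDescriptors())
      .filter(p -> !"class".equals(p.getName()))  // drop the implicit "class" property
      .filter(p -> p.getReadMethod() != null)     // keep only properties that have a getter
      .toArray(PropertyDescriptor[]::new);
    // Prints 0: the setter-only property "a" is skipped, so the inferred schema is empty.
    System.out.println(readable.length);
  }
}
```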

**Support an empty bean in encoder creation**

```java
public static class EmptyBean implements Serializable {}

EmptyBean bean = new EmptyBean();
List<EmptyBean> data = Arrays.asList(bean);
spark.createDataset(data, Encoders.bean(EmptyBean.class)).show();
```

- Before

It throws an exception as below:

```
java.lang.UnsupportedOperationException: Cannot infer type for class EmptyBean because it is not bean-compliant
	at org.apache.spark.sql.catalyst.JavaTypeInference$.org$apache$spark$sql$catalyst$JavaTypeInference$$serializerFor(JavaTypeInference.scala:436)
	at org.apache.spark.sql.catalyst.JavaTypeInference$.serializerFor(JavaTypeInference.scala:341)
```

- After

```
++
||
++
||
++
```
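
Both entry points now degrade to an empty schema instead of failing. Below is a small end-to-end usage sketch combining the two scenarios; it assumes a local `SparkSession`, and the class and variable names are illustrative, mirroring the tests added in the diff:

```java
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Sketch only: mirrors the behavior exercised by the new tests in this patch.
public class EmptySchemaDemo {
  public static class BeanWithoutGetter implements Serializable {
    private String a;
    public void setA(String a) { this.a = a; }
  }

  public static class EmptyBean implements Serializable {}

  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder().master("local[1]").appName("empty-schema-demo").getOrCreate();

    // Fix 1: a setter-only property is skipped, so the DataFrame has one row and zero columns.
    List<BeanWithoutGetter> beans = Arrays.asList(new BeanWithoutGetter());
    Dataset<Row> df = spark.createDataFrame(beans, BeanWithoutGetter.class);
    df.show();

    // Fix 2: an empty bean yields an encoder with an empty struct instead of throwing.
    List<EmptyBean> empties = Arrays.asList(new EmptyBean());
    Dataset<EmptyBean> ds = spark.createDataset(empties, Encoders.bean(EmptyBean.class));
    ds.show();

    spark.stop();
  }
}
```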

## How was this patch tested?

Unit tests in `JavaDataFrameSuite` and `JavaDatasetSuite`.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #17013 from HyukjinKwon/SPARK-19666.
@@ -117,11 +117,10 @@ object JavaTypeInference {
         val (valueDataType, nullable) = inferDataType(valueType)
         (MapType(keyDataType, valueDataType, nullable), true)
 
-      case _ =>
+      case other =>
         // TODO: we should only collect properties that have getter and setter. However, some tests
         // pass in scala case class as java bean class which doesn't have getter and setter.
-        val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
-        val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+        val properties = getJavaBeanReadableProperties(other)
         val fields = properties.map { property =>
           val returnType = typeToken.method(property.getReadMethod).getReturnType
           val (dataType, nullable) = inferDataType(returnType)
@@ -131,10 +130,15 @@ object JavaTypeInference {
     }
   }
 
-  private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
+  def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
     val beanInfo = Introspector.getBeanInfo(beanClass)
-    beanInfo.getPropertyDescriptors
-      .filter(p => p.getReadMethod != null && p.getWriteMethod != null)
+    beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+      .filter(_.getReadMethod != null)
+  }
+
+  private def getJavaBeanReadableAndWritableProperties(
+      beanClass: Class[_]): Array[PropertyDescriptor] = {
+    getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null)
   }
 
   private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
@@ -298,9 +302,7 @@ object JavaTypeInference {
            keyData :: valueData :: Nil)
 
         case other =>
-          val properties = getJavaBeanProperties(other)
-          assert(properties.length > 0)
-
+          val properties = getJavaBeanReadableAndWritableProperties(other)
           val setters = properties.map { p =>
             val fieldName = p.getName
             val fieldType = typeToken.method(p.getReadMethod).getReturnType
@@ -417,21 +419,16 @@ object JavaTypeInference {
           )
 
         case other =>
-          val properties = getJavaBeanProperties(other)
-          if (properties.length > 0) {
-            CreateNamedStruct(properties.flatMap { p =>
-              val fieldName = p.getName
-              val fieldType = typeToken.method(p.getReadMethod).getReturnType
-              val fieldValue = Invoke(
-                inputObject,
-                p.getReadMethod.getName,
-                inferExternalType(fieldType.getRawType))
-              expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
-            })
-          } else {
-            throw new UnsupportedOperationException(
-              s"Cannot infer type for class ${other.getName} because it is not bean-compliant")
-          }
+          val properties = getJavaBeanReadableAndWritableProperties(other)
+          CreateNamedStruct(properties.flatMap { p =>
+            val fieldName = p.getName
+            val fieldType = typeToken.method(p.getReadMethod).getReturnType
+            val fieldValue = Invoke(
+              inputObject,
+              p.getReadMethod.getName,
+              inferExternalType(fieldType.getRawType))
+            expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
+          })
       }
     }
   }
@@ -1090,14 +1090,14 @@ object SQLContext {
    */
   private[sql] def beansToRows(
       data: Iterator[_],
-      beanInfo: BeanInfo,
+      beanClass: Class[_],
       attrs: Seq[AttributeReference]): Iterator[InternalRow] = {
     val extractors =
-      beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
+      JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod)
     val methodsToConverts = extractors.zip(attrs).map { case (e, attr) =>
       (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
     }
-    data.map{ element =>
+    data.map { element =>
       new GenericInternalRow(
         methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) }
       ): InternalRow
@@ -17,7 +17,6 @@
 package org.apache.spark.sql
 
-import java.beans.Introspector
 import java.io.Closeable
 import java.util.concurrent.atomic.AtomicReference
@@ -347,8 +346,7 @@ class SparkSession private(
     val className = beanClass.getName
     val rowRdd = rdd.mapPartitions { iter =>
       // BeanInfo is not serializable so we must rediscover it remotely for each partition.
-      val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className))
-      SQLContext.beansToRows(iter, localBeanInfo, attributeSeq)
+      SQLContext.beansToRows(iter, Utils.classForName(className), attributeSeq)
     }
     Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self))
   }
@@ -374,8 +372,7 @@
    */
   def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
     val attrSeq = getSchema(beanClass)
-    val beanInfo = Introspector.getBeanInfo(beanClass)
-    val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq)
+    val rows = SQLContext.beansToRows(data.asScala.iterator, beanClass, attrSeq)
     Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq))
   }
@@ -397,4 +397,21 @@ public class JavaDataFrameSuite {
       Assert.assertTrue(filter4.mightContain(i * 3));
     }
   }
+
+  public static class BeanWithoutGetter implements Serializable {
+    private String a;
+
+    public void setA(String a) {
+      this.a = a;
+    }
+  }
+
+  @Test
+  public void testBeanWithoutGetter() {
+    BeanWithoutGetter bean = new BeanWithoutGetter();
+    List<BeanWithoutGetter> data = Arrays.asList(bean);
+    Dataset<Row> df = spark.createDataFrame(data, BeanWithoutGetter.class);
+    Assert.assertEquals(df.schema().length(), 0);
+    Assert.assertEquals(df.collectAsList().size(), 1);
+  }
 }
@@ -1276,4 +1276,15 @@ public class JavaDatasetSuite implements Serializable {
         spark.createDataset(data, Encoders.bean(NestedComplicatedJavaBean.class));
     ds.collectAsList();
   }
+
+  public static class EmptyBean implements Serializable {}
+
+  @Test
+  public void testEmptyBean() {
+    EmptyBean bean = new EmptyBean();
+    List<EmptyBean> data = Arrays.asList(bean);
+    Dataset<EmptyBean> df = spark.createDataset(data, Encoders.bean(EmptyBean.class));
+    Assert.assertEquals(df.schema().length(), 0);
+    Assert.assertEquals(df.collectAsList().size(), 1);
+  }
 }