Commit a04428fe authored by Takeshi Yamamuro, committed by Wenchen Fan

[SPARK-19980][SQL][BACKPORT-2.1] Add NULL checks in Bean serializer

## What changes were proposed in this pull request?
The Bean serializer in `ExpressionEncoder` can change values when a Bean contains NULLs. A concrete example follows:
```
scala> :paste
class Outer extends Serializable {
  private var cls: Inner = _
  def setCls(c: Inner): Unit = cls = c
  def getCls(): Inner = cls
}

class Inner extends Serializable {
  private var str: String = _
  def setStr(s: String): Unit = str = s
  def getStr(): String = str
}

scala> Seq("""{"cls":null}""", """{"cls": {"str":null}}""").toDF().write.text("data")
scala> val encoder = Encoders.bean(classOf[Outer])
scala> val schema = encoder.schema
scala> val df = spark.read.schema(schema).json("data").as[Outer](encoder)
scala> df.show
+------+
|   cls|
+------+
|[null]|
|  null|
+------+

scala> df.map(x => x)(encoder).show()
+------+
|   cls|
+------+
|[null]|
|[null]|     // <-- Value changed
+------+
```

This happens because the Bean serializer lacks the NULL-check expressions that the serializer for Scala's product types has. Indeed, the value change does not occur with Scala's product types:

```
scala> :paste
case class Outer(cls: Inner)
case class Inner(str: String)

scala> val encoder = Encoders.product[Outer]
scala> val schema = encoder.schema
scala> val df = spark.read.schema(schema).json("data").as[Outer](encoder)
scala> df.show
+------+
|   cls|
+------+
|[null]|
|  null|
+------+

scala> df.map(x => x)(encoder).show()
+------+
|   cls|
+------+
|[null]|
|  null|
+------+
```

This PR adds the same NULL-check expressions to the Bean serializer, mirroring the serializer for Scala's product types.
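
For illustration, the core of the added check can be sketched as follows (a minimal sketch distilled from the diff below; `nullSafeStruct` is a hypothetical helper name, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, IsNull, Literal}

// Sketch: wrap the struct built from a bean's getters in a NULL check, so a
// null bean serializes to a null struct instead of a struct full of nulls.
def nullSafeStruct(inputObject: Expression, nonNullOutput: CreateNamedStruct): Expression = {
  val nullOutput = Literal.create(null, nonNullOutput.dataType)
  If(IsNull(inputObject), nullOutput, nonNullOutput)
}
```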

## How was this patch tested?
Added tests in `JavaDatasetSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #17372 from maropu/SPARK-19980-BACKPORT2.1.
parent 9dfdd2ad
JavaTypeInference.scala:
```diff
@@ -334,7 +334,11 @@ object JavaTypeInference {
    */
   def serializerFor(beanClass: Class[_]): CreateNamedStruct = {
     val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
-    serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
+    val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean"))
+    serializerFor(nullSafeInput, TypeToken.of(beanClass)) match {
+      case expressions.If(_, _, s: CreateNamedStruct) => s
+      case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
+    }
   }
 
   private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
@@ -417,7 +421,7 @@ object JavaTypeInference {
       case other =>
         val properties = getJavaBeanProperties(other)
         if (properties.length > 0) {
-          CreateNamedStruct(properties.flatMap { p =>
+          val nonNullOutput = CreateNamedStruct(properties.flatMap { p =>
             val fieldName = p.getName
             val fieldType = typeToken.method(p.getReadMethod).getReturnType
             val fieldValue = Invoke(
@@ -426,6 +430,9 @@ object JavaTypeInference {
               inferExternalType(fieldType.getRawType))
             expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
           })
+
+          val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+          expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
         } else {
           throw new UnsupportedOperationException(
             s"Cannot infer type for class ${other.getName} because it is not bean-compliant")
```
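
A note on the first hunk: the recursive `serializerFor` now returns bean structs wrapped in an `If(IsNull(...), ...)` check, while the top-level bean is guarded by `AssertNotNull` and can never be null, so the public `serializerFor` strips the wrapper to keep returning a `CreateNamedStruct`. A minimal sketch of that unwrapping (`toTopLevelStruct` is a hypothetical helper, assuming the Catalyst expression classes used above):

```scala
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, Literal}

// Hypothetical helper mirroring the top-level match above: drop the null
// branch (unreachable for a non-null input), or wrap a non-struct serializer
// in a single-column struct named "value".
def toTopLevelStruct(serializer: Expression): CreateNamedStruct = serializer match {
  case If(_, _, struct: CreateNamedStruct) => struct
  case other => CreateNamedStruct(Literal("value") :: other :: Nil)
}
```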
JavaDatasetSuite.java:
```diff
@@ -1305,4 +1305,28 @@ public class JavaDatasetSuite implements Serializable {
       spark.createDataset(data, Encoders.bean(NestedComplicatedJavaBean.class));
     ds.collectAsList();
   }
+
+  @Test(expected = RuntimeException.class)
+  public void testNullInTopLevelBean() {
+    NestedSmallBean bean = new NestedSmallBean();
+    // We cannot set null in top-level bean
+    spark.createDataset(Arrays.asList(bean, null), Encoders.bean(NestedSmallBean.class));
+  }
+
+  @Test
+  public void testSerializeNull() {
+    NestedSmallBean bean = new NestedSmallBean();
+    Encoder<NestedSmallBean> encoder = Encoders.bean(NestedSmallBean.class);
+    List<NestedSmallBean> beans = Arrays.asList(bean);
+    Dataset<NestedSmallBean> ds1 = spark.createDataset(beans, encoder);
+    Assert.assertEquals(beans, ds1.collectAsList());
+    Dataset<NestedSmallBean> ds2 =
+      ds1.map(new MapFunction<NestedSmallBean, NestedSmallBean>() {
+        @Override
+        public NestedSmallBean call(NestedSmallBean b) throws Exception {
+          return b;
+        }
+      }, encoder);
+    Assert.assertEquals(beans, ds2.collectAsList());
+  }
 }
```