diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 4ff87edde139a117ff8eb7b2320f43c47199a422..9d4617dda555fcb88b18882e6ac9ba96520d8c1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -343,7 +343,11 @@ object JavaTypeInference { */ def serializerFor(beanClass: Class[_]): CreateNamedStruct = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] + val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean")) + serializerFor(nullSafeInput, TypeToken.of(beanClass)) match { + case expressions.If(_, _, s: CreateNamedStruct) => s + case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) + } } private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { @@ -427,7 +431,7 @@ object JavaTypeInference { case other => val properties = getJavaBeanReadableAndWritableProperties(other) - CreateNamedStruct(properties.flatMap { p => + val nonNullOutput = CreateNamedStruct(properties.flatMap { p => val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val fieldValue = Invoke( @@ -436,6 +440,9 @@ object JavaTypeInference { inferExternalType(fieldType.getRawType)) expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil }) + + val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) + expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) } } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index ca9e5ad2ea86bd91761cfb48d30cfb3744eaf63e..ffb4c6273ff85899974d460281c1e03a8ae627f4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1380,4 +1380,23 @@ public class JavaDatasetSuite implements Serializable { CircularReference4Bean bean = new CircularReference4Bean(); spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference4Bean.class)); } + + @Test(expected = RuntimeException.class) + public void testNullInTopLevelBean() { + NestedSmallBean bean = new NestedSmallBean(); + // We cannot set null in top-level bean + spark.createDataset(Arrays.asList(bean, null), Encoders.bean(NestedSmallBean.class)); + } + + @Test + public void testSerializeNull() { + NestedSmallBean bean = new NestedSmallBean(); + Encoder<NestedSmallBean> encoder = Encoders.bean(NestedSmallBean.class); + List<NestedSmallBean> beans = Arrays.asList(bean); + Dataset<NestedSmallBean> ds1 = spark.createDataset(beans, encoder); + Assert.assertEquals(beans, ds1.collectAsList()); + Dataset<NestedSmallBean> ds2 = + ds1.map((MapFunction<NestedSmallBean, NestedSmallBean>) b -> b, encoder); + Assert.assertEquals(beans, ds2.collectAsList()); + } }