diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 4ff87edde139a117ff8eb7b2320f43c47199a422..9d4617dda555fcb88b18882e6ac9ba96520d8c1e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -343,7 +343,13 @@ object JavaTypeInference {
    */
   def serializerFor(beanClass: Class[_]): CreateNamedStruct = {
     val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
-    serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
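+    // Fail fast with a RuntimeException if the top-level bean itself is null.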
+    val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean"))
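+    // Strip the null check added around bean structs; other values get a single "value" column.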
+    serializerFor(nullSafeInput, TypeToken.of(beanClass)) match {
+      case expressions.If(_, _, s: CreateNamedStruct) => s
+      case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
+    }
   }
 
   private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
@@ -427,7 +433,7 @@
 
         case other =>
           val properties = getJavaBeanReadableAndWritableProperties(other)
-          CreateNamedStruct(properties.flatMap { p =>
+          val nonNullOutput = CreateNamedStruct(properties.flatMap { p =>
             val fieldName = p.getName
             val fieldType = typeToken.method(p.getReadMethod).getReturnType
             val fieldValue = Invoke(
@@ -436,6 +442,10 @@
               inferExternalType(fieldType.getRawType))
             expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
           })
+
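+          // A null bean (e.g. a null nested bean field) serializes to a null struct.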
+          val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+          expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
       }
     }
   }
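
A note on the shape of the result: nested beans are now serialized through
If(IsNull(bean), null, CreateNamedStruct(...)), while the top-level bean is
guarded by AssertNotNull with the redundant null check stripped. The hand-built
sketch below (a hypothetical bean with a single getName(): String property, not
code from this patch) illustrates the expression tree the serializer produces:

    import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, If, IsNull, Literal}
    import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
    import org.apache.spark.sql.types.{ObjectType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    // Hypothetical bean with one readable property, getName(): String.
    val bean = BoundReference(0, ObjectType(classOf[Object]), nullable = true)
    val name = Invoke(bean, "getName", ObjectType(classOf[String]))
    val nonNullOutput = CreateNamedStruct(
      Literal("name") ::
      StaticInvoke(classOf[UTF8String], StringType, "fromString", name :: Nil) :: Nil)
    // The fix in a nutshell: a null bean yields a null struct up front, so its
    // getters are never invoked on a null reference.
    val serializer = If(IsNull(bean), Literal.create(null, nonNullOutput.dataType), nonNullOutput)
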
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index ca9e5ad2ea86bd91761cfb48d30cfb3744eaf63e..ffb4c6273ff85899974d460281c1e03a8ae627f4 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -1380,4 +1380,24 @@ public class JavaDatasetSuite implements Serializable {
     CircularReference4Bean bean = new CircularReference4Bean();
     spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference4Bean.class));
   }
+
+  @Test(expected = RuntimeException.class)
+  public void testNullInTopLevelBean() {
+    NestedSmallBean bean = new NestedSmallBean();
+    // A top-level bean must not be null, so this should throw a RuntimeException.
+    spark.createDataset(Arrays.asList(bean, null), Encoders.bean(NestedSmallBean.class));
+  }
+
+  @Test
+  public void testSerializeNull() {
+    NestedSmallBean bean = new NestedSmallBean();
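+    // The nested field is left null, so the round trip must preserve a null nested struct.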
+    Encoder<NestedSmallBean> encoder = Encoders.bean(NestedSmallBean.class);
+    List<NestedSmallBean> beans = Arrays.asList(bean);
+    Dataset<NestedSmallBean> ds1 = spark.createDataset(beans, encoder);
+    Assert.assertEquals(beans, ds1.collectAsList());
+    Dataset<NestedSmallBean> ds2 =
+      ds1.map((MapFunction<NestedSmallBean, NestedSmallBean>) b -> b, encoder);
+    Assert.assertEquals(beans, ds2.collectAsList());
+  }
 }
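
For reference, the same two behaviors can be exercised from Scala. A minimal
sketch, assuming spark is an active SparkSession and a NestedSmallBean-style
bean (whose nested property defaults to null) is on the classpath:

    import java.util.Arrays
    import scala.util.Try
    import org.apache.spark.sql.Encoders

    val enc = Encoders.bean(classOf[NestedSmallBean])
    // A null nested field round-trips as a null struct:
    val ds = spark.createDataset(Arrays.asList(new NestedSmallBean))(enc)
    assert(ds.collectAsList().get(0).getNested == null)
    // A null top-level bean fails fast when the local data is serialized:
    assert(Try(spark.createDataset(Arrays.asList(new NestedSmallBean, null))(enc)).isFailure)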