diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 206ae2f0e5eb1f28f9ebd18b70914d89244d92a2..198122759e4ad796d80ccaeb9a00e8e2cf55d970 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -251,19 +251,22 @@ object ScalaReflection extends ScalaReflection {
           getPath :: Nil)
 
       case t if t <:< localTypeOf[java.lang.String] =>
-        Invoke(getPath, "toString", ObjectType(classOf[String]))
+        Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false)
 
       case t if t <:< localTypeOf[java.math.BigDecimal] =>
-        Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+        Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
+          returnNullable = false)
 
       case t if t <:< localTypeOf[BigDecimal] =>
-        Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))
+        Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false)
 
       case t if t <:< localTypeOf[java.math.BigInteger] =>
-        Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]))
+        Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
+          returnNullable = false)
 
       case t if t <:< localTypeOf[scala.math.BigInt] =>
-        Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]))
+        Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
+          returnNullable = false)
 
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
@@ -284,7 +287,7 @@ object ScalaReflection extends ScalaReflection {
         val arrayCls = arrayClassFor(elementType)
 
         if (elementNullable) {
-          Invoke(arrayData, "array", arrayCls)
+          Invoke(arrayData, "array", arrayCls, returnNullable = false)
         } else {
           val primitiveMethod = elementType match {
             case t if t <:< definitions.IntTpe => "toIntArray"
@@ -297,7 +300,7 @@ object ScalaReflection extends ScalaReflection {
             case other => throw new IllegalStateException("expect primitive array element type " +
               "but got " + other)
           }
-          Invoke(arrayData, primitiveMethod, arrayCls)
+          Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false)
         }
 
       case t if t <:< localTypeOf[Seq[_]] =>
@@ -330,19 +333,21 @@ object ScalaReflection extends ScalaReflection {
           Invoke(
             MapObjects(
               p => deserializerFor(keyType, Some(p), walkedTypePath),
-              Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
+              Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType),
+                returnNullable = false),
               schemaFor(keyType).dataType),
             "array",
-            ObjectType(classOf[Array[Any]]))
+            ObjectType(classOf[Array[Any]]), returnNullable = false)
 
         val valueData =
           Invoke(
             MapObjects(
               p => deserializerFor(valueType, Some(p), walkedTypePath),
-              Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
+              Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType),
+                returnNullable = false),
               schemaFor(valueType).dataType),
             "array",
-            ObjectType(classOf[Array[Any]]))
+            ObjectType(classOf[Array[Any]]), returnNullable = false)
 
         StaticInvoke(
           ArrayBasedMapData.getClass,
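
Note: every method wrapped above returns non-null for a non-null receiver
(UTF8String.toString, Decimal.toJavaBigDecimal, ArrayData.toIntArray,
MapData.keyArray, and so on), so the deserializer can mark each Invoke result
as non-nullable and let codegen drop the per-row null check. A hedged sketch
of the constructor shape these call sites assume; only the returnNullable
parameter and the positional (target, name, type, args) order are confirmed
by this diff, the rest is an assumption:

    // Sketch only, not the verbatim definition from objects.scala.
    // returnNullable defaults to true, so untouched call sites keep
    // their defensive null check.
    case class Invoke(
        targetObject: Expression,
        functionName: String,
        dataType: DataType,
        arguments: Seq[Expression] = Nil,
        propagateNull: Boolean = true,    // assumed, not shown in this diff
        returnNullable: Boolean = true)
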
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index e95e97b9dc6cba583e9948b02ca6dcdf2f9f4903..0f8282d3b2f1f196a6d59fe4ce4c6bb32d9eb292 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -89,7 +89,7 @@ object RowEncoder {
         udtClass,
         Nil,
         dataType = ObjectType(udtClass), false)
-      Invoke(obj, "serialize", udt, inputObject :: Nil)
+      Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
 
     case TimestampType =>
       StaticInvoke(
@@ -136,16 +136,18 @@ object RowEncoder {
     case t @ MapType(kt, vt, valueNullable) =>
       val keys =
         Invoke(
-          Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+          Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]),
+            returnNullable = false),
           "toSeq",
-          ObjectType(classOf[scala.collection.Seq[_]]))
+          ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
       val convertedKeys = serializerFor(keys, ArrayType(kt, false))
 
       val values =
         Invoke(
-          Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+          Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]),
+            returnNullable = false),
           "toSeq",
-          ObjectType(classOf[scala.collection.Seq[_]]))
+          ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
       val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))
 
       NewInstance(
@@ -262,17 +264,18 @@ object RowEncoder {
         input :: Nil)
 
     case _: DecimalType =>
-      Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+      Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
+        returnNullable = false)
 
     case StringType =>
-      Invoke(input, "toString", ObjectType(classOf[String]))
+      Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false)
 
     case ArrayType(et, nullable) =>
       val arrayData =
         Invoke(
           MapObjects(deserializerFor(_), input, et),
           "array",
-          ObjectType(classOf[Array[_]]))
+          ObjectType(classOf[Array[_]]), returnNullable = false)
       StaticInvoke(
         scala.collection.mutable.WrappedArray.getClass,
         ObjectType(classOf[Seq[_]]),
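
The RowEncoder changes follow the same rule on both sides: serializing a Scala
Map walks keysIterator/valuesIterator, which never return null for a non-null
map, and deserializing decimals, strings, and arrays calls converters that
always produce a value. A minimal sketch that exercises these paths (assumes a
local Spark build on the classpath; the object and app names are illustrative):

    import org.apache.spark.sql.SparkSession

    object RowEncoderPathsSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]").appName("sketch").getOrCreate()
        import spark.implicits._

        // A map-typed column goes through the keysIterator/valuesIterator
        // Invokes above; their results are now flagged non-nullable, so the
        // serializer's generated code skips those null branches.
        val df = Seq(Map("a" -> 1, "b" -> 2)).toDF("m")
        df.collect().foreach(println)

        spark.stop()
      }
    }
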
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 53842ef348a57f2ac8fa7368aff2abc14338e2e3..6d94764f1bfacc256c0d29874deffa63c8369684 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -225,25 +225,26 @@ case class Invoke(
       getFuncResult(ev.value, s"${obj.value}.$functionName($argString)")
     } else {
       val funcResult = ctx.freshName("funcResult")
+      // If the function can return null, do an extra runtime check so the null
+      // bit is set correctly; otherwise assign the result directly.
+      val assignResult = if (!returnNullable) {
+        s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;"
+      } else {
+        s"""
+          if ($funcResult != null) {
+            ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
+          } else {
+            ${ev.isNull} = true;
+          }
+        """
+      }
       s"""
         Object $funcResult = null;
         ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")}
-        if ($funcResult == null) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
-        }
+        $assignResult
       """
     }
 
-    // If the function can return null, we do an extra check to make sure our null bit is still set
-    // correctly.
-    val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
-      s"${ev.isNull} = ${ev.value} == null;"
-    } else {
-      ""
-    }
-
     val code = s"""
       ${obj.code}
       boolean ${ev.isNull} = true;
@@ -254,7 +255,6 @@ case class Invoke(
         if (!${ev.isNull}) {
           $evaluate
         }
-        $postNullCheck
       }
      """
     ev.copy(code = code)
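
Net effect of the Invoke codegen change: the old version always emitted a
runtime null check, plus a postNullCheck that re-derived the null bit for
object types; the new version decides at codegen time. For a call such as
Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable =
false), the template above now evaluates to Java along these lines
(identifier names are illustrative; the real ones come from ctx.freshName):

    // returnNullable = false: assign directly, no null branch.
    Object funcResult = null;
    funcResult = obj.toString();
    value = (java.lang.String) funcResult;

    // returnNullable = true: re-check the null bit at runtime.
    Object funcResult = null;
    funcResult = obj.toString();
    if (funcResult != null) {
      value = (java.lang.String) funcResult;
    } else {
      isNull = true;
    }
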
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 82b707537e45f0f87eb2fcfde353f9317679c592..541565344f758aa4a45a29b00eb7b62474947f0e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -96,6 +96,16 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(dsBoolean.map(e => !e), false, true)
   }
 
+  test("mapPrimitiveArray") {
+    val dsInt = Seq(Array(1, 2), Array(3, 4)).toDS()
+    checkDataset(dsInt.map(e => e), Array(1, 2), Array(3, 4))
+    checkDataset(dsInt.map(e => null: Array[Int]), null, null)
+
+    val dsDouble = Seq(Array(1D, 2D), Array(3D, 4D)).toDS()
+    checkDataset(dsDouble.map(e => e), Array(1D, 2D), Array(3D, 4D))
+    checkDataset(dsDouble.map(e => null: Array[Double]), null, null)
+  }
+
   test("filter") {
     val ds = Seq(1, 2, 3, 4).toDS()
     checkDataset(
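
The new test pins down both sides of the change. As a standalone repro, a
minimal sketch for spark-shell (where spark.implicits._ is already in scope):

    val dsInt = Seq(Array(1, 2), Array(3, 4)).toDS()
    // Fast path: the primitive toIntArray result is assigned without a
    // null check.
    dsInt.map(e => e).collect()
    // Guard rail: a lambda that really does return null must still come
    // back as a null row rather than a NullPointerException.
    dsInt.map(e => null: Array[Int]).collect()

The suite itself runs with something like build/sbt "sql/test-only
*DatasetPrimitiveSuite" (invocation per the Spark developer docs of this era).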