From aa48164a43bd9ed9eab53fcacbed92819e84eaf7 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Wed, 30 Dec 2015 10:56:08 -0800
Subject: [PATCH] [SPARK-12495][SQL] use true as default value for
 propagateNull in NewInstance

Most of cases we should propagate null when call `NewInstance`, and so far there is only one case we should stop null propagation: create product/java bean. So I think it makes more sense to propagate null by dafault.

This also fixes a bug when encode null array/map, which is firstly discovered in https://github.com/apache/spark/pull/10401

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10443 from cloud-fan/encoder.
---
 .../sql/catalyst/JavaTypeInference.scala      | 16 ++++++-------
 .../spark/sql/catalyst/ScalaReflection.scala  | 16 ++++++-------
 .../catalyst/encoders/ExpressionEncoder.scala |  2 +-
 .../sql/catalyst/encoders/RowEncoder.scala    |  2 --
 .../sql/catalyst/expressions/objects.scala    | 12 +++++-----
 .../encoders/EncoderResolutionSuite.scala     | 24 +++++++++----------
 .../encoders/ExpressionEncoderSuite.scala     |  3 +++
 7 files changed, 38 insertions(+), 37 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index a1500cbc30..ed153d1f88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -178,19 +178,19 @@ object JavaTypeInference {
       case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
 
       case c if c == classOf[java.lang.Short] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
       case c if c == classOf[java.lang.Integer] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
       case c if c == classOf[java.lang.Long] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
       case c if c == classOf[java.lang.Double] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
       case c if c == classOf[java.lang.Byte] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
       case c if c == classOf[java.lang.Float] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
       case c if c == classOf[java.lang.Boolean] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
 
       case c if c == classOf[java.sql.Date] =>
         StaticInvoke(
@@ -298,7 +298,7 @@ object JavaTypeInference {
           p.getWriteMethod.getName -> setter
         }.toMap
 
-        val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
+        val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false)
         val result = InitializeJavaBean(newInstance, setters)
 
         if (path.nonEmpty) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8a22b37d07..9784c96966 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -189,37 +189,37 @@ object ScalaReflection extends ScalaReflection {
       case t if t <:< localTypeOf[java.lang.Integer] =>
         val boxedType = classOf[java.lang.Integer]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.lang.Long] =>
         val boxedType = classOf[java.lang.Long]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.lang.Double] =>
         val boxedType = classOf[java.lang.Double]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.lang.Float] =>
         val boxedType = classOf[java.lang.Float]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.lang.Short] =>
         val boxedType = classOf[java.lang.Short]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.lang.Byte] =>
         val boxedType = classOf[java.lang.Byte]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.lang.Boolean] =>
         val boxedType = classOf[java.lang.Boolean]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.sql.Date] =>
         StaticInvoke(
@@ -349,7 +349,7 @@ object ScalaReflection extends ScalaReflection {
           }
         }
 
-        val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
+        val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
 
         if (path.nonEmpty) {
           expressions.If(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 7a4401cf58..ad4beda9c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -133,7 +133,7 @@ object ExpressionEncoder {
     }
 
     val fromRowExpression =
-      NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls))
+      NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false)
 
     new ExpressionEncoder[Any](
       schema,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 63bdf05ca7..6f3d5ba84c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -55,7 +55,6 @@ object RowEncoder {
       val obj = NewInstance(
         udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
         Nil,
-        false,
         dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
       Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
 
@@ -166,7 +165,6 @@ object RowEncoder {
       val obj = NewInstance(
         udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
         Nil,
-        false,
         dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
       Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index d40cd96905..fb404c12d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -165,7 +165,7 @@ case class Invoke(
       ${obj.code}
       ${argGen.map(_.code).mkString("\n")}
 
-      boolean ${ev.isNull} = ${obj.value} == null;
+      boolean ${ev.isNull} = ${obj.isNull};
       $javaType ${ev.value} =
         ${ev.isNull} ?
         ${ctx.defaultValue(dataType)} : ($javaType) $value;
@@ -178,8 +178,8 @@ object NewInstance {
   def apply(
       cls: Class[_],
       arguments: Seq[Expression],
-      propagateNull: Boolean = false,
-      dataType: DataType): NewInstance =
+      dataType: DataType,
+      propagateNull: Boolean = true): NewInstance =
     new NewInstance(cls, arguments, propagateNull, dataType, None)
 }
 
@@ -231,7 +231,7 @@ case class NewInstance(
       s"new $className($argString)"
     }
 
-    if (propagateNull) {
+    if (propagateNull && argGen.nonEmpty) {
       val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
 
       s"""
@@ -248,8 +248,8 @@ case class NewInstance(
       s"""
         $setup
 
-        $javaType ${ev.value} = $constructorCall;
-        final boolean ${ev.isNull} = ${ev.value} == null;
+        final $javaType ${ev.value} = $constructorCall;
+        final boolean ${ev.isNull} = false;
       """
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 764ffdc094..bc36a55ae0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -46,8 +46,8 @@ class EncoderResolutionSuite extends PlanTest {
           toExternalString('a.string),
           AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
         ),
-        false,
-        ObjectType(cls))
+        ObjectType(cls),
+        propagateNull = false)
       compareExpressions(fromRowExpr, expected)
     }
 
@@ -60,8 +60,8 @@ class EncoderResolutionSuite extends PlanTest {
           toExternalString('a.int.cast(StringType)),
           AssertNotNull('b.long, cls.getName, "b", "Long")
         ),
-        false,
-        ObjectType(cls))
+        ObjectType(cls),
+        propagateNull = false)
       compareExpressions(fromRowExpr, expected)
     }
   }
@@ -88,11 +88,11 @@ class EncoderResolutionSuite extends PlanTest {
               AssertNotNull(
                 GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
                 innerCls.getName, "b", "Long")),
-            false,
-            ObjectType(innerCls))
+            ObjectType(innerCls),
+            propagateNull = false)
         )),
-      false,
-      ObjectType(cls))
+      ObjectType(cls),
+      propagateNull = false)
     compareExpressions(fromRowExpr, expected)
   }
 
@@ -114,11 +114,11 @@ class EncoderResolutionSuite extends PlanTest {
             AssertNotNull(
               GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
               cls.getName, "b", "Long")),
-          false,
-          ObjectType(cls)),
+          ObjectType(cls),
+          propagateNull = false),
         'b.int.cast(LongType)),
-      false,
-      ObjectType(classOf[Tuple2[_, _]]))
+      ObjectType(classOf[Tuple2[_, _]]),
+      propagateNull = false)
     compareExpressions(fromRowExpr, expected)
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 7233e0f1b5..666699e18d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -128,6 +128,9 @@ class ExpressionEncoderSuite extends SparkFunSuite {
   encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null")
   encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map")
 
+  encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple")
+  encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple")
+
   // Kryo encoders
   encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String]))
   encodeDecodeTest(new KryoSerializable(15), "kryo object")(
-- 
GitLab