From aa48164a43bd9ed9eab53fcacbed92819e84eaf7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan <wenchen@databricks.com> Date: Wed, 30 Dec 2015 10:56:08 -0800 Subject: [PATCH] [SPARK-12495][SQL] use true as default value for propagateNull in NewInstance Most of cases we should propagate null when call `NewInstance`, and so far there is only one case we should stop null propagation: create product/java bean. So I think it makes more sense to propagate null by default. This also fixes a bug when encode null array/map, which is firstly discovered in https://github.com/apache/spark/pull/10401 Author: Wenchen Fan <wenchen@databricks.com> Closes #10443 from cloud-fan/encoder. --- .../sql/catalyst/JavaTypeInference.scala | 16 ++++++------- .../spark/sql/catalyst/ScalaReflection.scala | 16 ++++++------- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../sql/catalyst/encoders/RowEncoder.scala | 2 -- .../sql/catalyst/expressions/objects.scala | 12 +++++----- .../encoders/EncoderResolutionSuite.scala | 24 +++++++++---------- .../encoders/ExpressionEncoderSuite.scala | 3 +++ 7 files changed, 38 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index a1500cbc30..ed153d1f88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -178,19 +178,19 @@ object JavaTypeInference { case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath case c if c == classOf[java.lang.Short] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Integer] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == 
classOf[java.lang.Long] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Double] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Byte] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Float] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Boolean] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.sql.Date] => StaticInvoke( @@ -298,7 +298,7 @@ object JavaTypeInference { p.getWriteMethod.getName -> setter }.toMap - val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other)) + val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false) val result = InitializeJavaBean(newInstance, setters) if (path.nonEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8a22b37d07..9784c96966 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -189,37 +189,37 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType 
= ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( @@ -349,7 +349,7 @@ object ScalaReflection extends ScalaReflection { } } - val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) + val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) if (path.nonEmpty) { expressions.If( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 
7a4401cf58..ad4beda9c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -133,7 +133,7 @@ object ExpressionEncoder { } val fromRowExpression = - NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls)) + NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false) new ExpressionEncoder[Any]( schema, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 63bdf05ca7..6f3d5ba84c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -55,7 +55,6 @@ object RowEncoder { val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, - false, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) @@ -166,7 +165,6 @@ object RowEncoder { val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, - false, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index d40cd96905..fb404c12d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -165,7 +165,7 @@ case class Invoke( ${obj.code} ${argGen.map(_.code).mkString("\n")} - boolean ${ev.isNull} = ${obj.value} == 
null; + boolean ${ev.isNull} = ${obj.isNull}; $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value; @@ -178,8 +178,8 @@ object NewInstance { def apply( cls: Class[_], arguments: Seq[Expression], - propagateNull: Boolean = false, - dataType: DataType): NewInstance = + dataType: DataType, + propagateNull: Boolean = true): NewInstance = new NewInstance(cls, arguments, propagateNull, dataType, None) } @@ -231,7 +231,7 @@ case class NewInstance( s"new $className($argString)" } - if (propagateNull) { + if (propagateNull && argGen.nonEmpty) { val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" s""" @@ -248,8 +248,8 @@ case class NewInstance( s""" $setup - $javaType ${ev.value} = $constructorCall; - final boolean ${ev.isNull} = ${ev.value} == null; + final $javaType ${ev.value} = $constructorCall; + final boolean ${ev.isNull} = false; """ } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 764ffdc094..bc36a55ae0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -46,8 +46,8 @@ class EncoderResolutionSuite extends PlanTest { toExternalString('a.string), AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long") ), - false, - ObjectType(cls)) + ObjectType(cls), + propagateNull = false) compareExpressions(fromRowExpr, expected) } @@ -60,8 +60,8 @@ class EncoderResolutionSuite extends PlanTest { toExternalString('a.int.cast(StringType)), AssertNotNull('b.long, cls.getName, "b", "Long") ), - false, - ObjectType(cls)) + ObjectType(cls), + propagateNull = false) compareExpressions(fromRowExpr, expected) } } @@ -88,11 +88,11 @@ class EncoderResolutionSuite extends PlanTest { AssertNotNull( 
GetStructField('b.struct('a.int, 'b.long), 1, Some("b")), innerCls.getName, "b", "Long")), - false, - ObjectType(innerCls)) + ObjectType(innerCls), + propagateNull = false) )), - false, - ObjectType(cls)) + ObjectType(cls), + propagateNull = false) compareExpressions(fromRowExpr, expected) } @@ -114,11 +114,11 @@ class EncoderResolutionSuite extends PlanTest { AssertNotNull( GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType), cls.getName, "b", "Long")), - false, - ObjectType(cls)), + ObjectType(cls), + propagateNull = false), 'b.int.cast(LongType)), - false, - ObjectType(classOf[Tuple2[_, _]])) + ObjectType(classOf[Tuple2[_, _]]), + propagateNull = false) compareExpressions(fromRowExpr, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 7233e0f1b5..666699e18d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -128,6 +128,9 @@ class ExpressionEncoderSuite extends SparkFunSuite { encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null") encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map") + encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple") + encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple") + // Kryo encoders encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String])) encodeDecodeTest(new KryoSerializable(15), "kryo object")( -- GitLab