Skip to content
Snippets Groups Projects
Commit aa48164a authored by Wenchen Fan's avatar Wenchen Fan Committed by Michael Armbrust
Browse files

[SPARK-12495][SQL] use true as default value for propagateNull in NewInstance

Most of cases we should propagate null when call `NewInstance`, and so far there is only one case we should stop null propagation: create product/java bean. So I think it makes more sense to propagate null by dafault.

This also fixes a bug when encode null array/map, which is firstly discovered in https://github.com/apache/spark/pull/10401

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10443 from cloud-fan/encoder.
parent 932cf442
No related branches found
No related tags found
No related merge requests found
......@@ -178,19 +178,19 @@ object JavaTypeInference {
case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
case c if c == classOf[java.lang.Short] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Integer] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Long] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Double] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Byte] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Float] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Boolean] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.sql.Date] =>
StaticInvoke(
......@@ -298,7 +298,7 @@ object JavaTypeInference {
p.getWriteMethod.getName -> setter
}.toMap
val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false)
val result = InitializeJavaBean(newInstance, setters)
if (path.nonEmpty) {
......
......@@ -189,37 +189,37 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.lang.Integer] =>
val boxedType = classOf[java.lang.Integer]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.lang.Long] =>
val boxedType = classOf[java.lang.Long]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.lang.Double] =>
val boxedType = classOf[java.lang.Double]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.lang.Float] =>
val boxedType = classOf[java.lang.Float]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.lang.Short] =>
val boxedType = classOf[java.lang.Short]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.lang.Byte] =>
val boxedType = classOf[java.lang.Byte]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.lang.Boolean] =>
val boxedType = classOf[java.lang.Boolean]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.sql.Date] =>
StaticInvoke(
......@@ -349,7 +349,7 @@ object ScalaReflection extends ScalaReflection {
}
}
val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
if (path.nonEmpty) {
expressions.If(
......
......@@ -133,7 +133,7 @@ object ExpressionEncoder {
}
val fromRowExpression =
NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls))
NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false)
new ExpressionEncoder[Any](
schema,
......
......@@ -55,7 +55,6 @@ object RowEncoder {
val obj = NewInstance(
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
false,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
......@@ -166,7 +165,6 @@ object RowEncoder {
val obj = NewInstance(
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
false,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
......
......@@ -165,7 +165,7 @@ case class Invoke(
${obj.code}
${argGen.map(_.code).mkString("\n")}
boolean ${ev.isNull} = ${obj.value} == null;
boolean ${ev.isNull} = ${obj.isNull};
$javaType ${ev.value} =
${ev.isNull} ?
${ctx.defaultValue(dataType)} : ($javaType) $value;
......@@ -178,8 +178,8 @@ object NewInstance {
def apply(
cls: Class[_],
arguments: Seq[Expression],
propagateNull: Boolean = false,
dataType: DataType): NewInstance =
dataType: DataType,
propagateNull: Boolean = true): NewInstance =
new NewInstance(cls, arguments, propagateNull, dataType, None)
}
......@@ -231,7 +231,7 @@ case class NewInstance(
s"new $className($argString)"
}
if (propagateNull) {
if (propagateNull && argGen.nonEmpty) {
val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
s"""
......@@ -248,8 +248,8 @@ case class NewInstance(
s"""
$setup
$javaType ${ev.value} = $constructorCall;
final boolean ${ev.isNull} = ${ev.value} == null;
final $javaType ${ev.value} = $constructorCall;
final boolean ${ev.isNull} = false;
"""
}
}
......
......@@ -46,8 +46,8 @@ class EncoderResolutionSuite extends PlanTest {
toExternalString('a.string),
AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
),
false,
ObjectType(cls))
ObjectType(cls),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
}
......@@ -60,8 +60,8 @@ class EncoderResolutionSuite extends PlanTest {
toExternalString('a.int.cast(StringType)),
AssertNotNull('b.long, cls.getName, "b", "Long")
),
false,
ObjectType(cls))
ObjectType(cls),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
}
}
......@@ -88,11 +88,11 @@ class EncoderResolutionSuite extends PlanTest {
AssertNotNull(
GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
innerCls.getName, "b", "Long")),
false,
ObjectType(innerCls))
ObjectType(innerCls),
propagateNull = false)
)),
false,
ObjectType(cls))
ObjectType(cls),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
}
......@@ -114,11 +114,11 @@ class EncoderResolutionSuite extends PlanTest {
AssertNotNull(
GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
cls.getName, "b", "Long")),
false,
ObjectType(cls)),
ObjectType(cls),
propagateNull = false),
'b.int.cast(LongType)),
false,
ObjectType(classOf[Tuple2[_, _]]))
ObjectType(classOf[Tuple2[_, _]]),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
}
......
......@@ -128,6 +128,9 @@ class ExpressionEncoderSuite extends SparkFunSuite {
encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null")
encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map")
encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple")
encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple")
// Kryo encoders
encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String]))
encodeDecodeTest(new KryoSerializable(15), "kryo object")(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment