diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 90ec699877dec90c2cf6d1b8585cb63f3fc5fd09..21363d3ba82c1a111b986a7c8583dff74fb51c7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -216,7 +216,7 @@ object JavaTypeInference { ObjectType(c), "valueOf", getPath :: Nil, - propagateNull = true) + returnNullable = false) case c if c == classOf[java.sql.Date] => StaticInvoke( @@ -224,7 +224,7 @@ object JavaTypeInference { ObjectType(c), "toJavaDate", getPath :: Nil, - propagateNull = true) + returnNullable = false) case c if c == classOf[java.sql.Timestamp] => StaticInvoke( @@ -232,7 +232,7 @@ object JavaTypeInference { ObjectType(c), "toJavaTimestamp", getPath :: Nil, - propagateNull = true) + returnNullable = false) case c if c == classOf[java.lang.String] => Invoke(getPath, "toString", ObjectType(classOf[String])) @@ -300,7 +300,8 @@ object JavaTypeInference { ArrayBasedMapData.getClass, ObjectType(classOf[JMap[_, _]]), "toJavaMap", - keyData :: valueData :: Nil) + keyData :: valueData :: Nil, + returnNullable = false) case other => val properties = getJavaBeanReadableAndWritableProperties(other) @@ -367,28 +368,32 @@ object JavaTypeInference { classOf[UTF8String], StringType, "fromString", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, DateType, "fromJavaDate", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.math.BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index bea0de4d90c2fff7bd73b5bb2890ab4cb6a1a5a8..814f2c10b909782a02a793444c0e4b693046a786 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -206,51 +206,53 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - getPath :: Nil) + getPath :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - getPath :: Nil) + getPath :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.lang.String] => Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) @@ -446,7 +448,8 @@ object ScalaReflection extends ScalaReflection { classOf[UnsafeArrayData], ArrayType(dt, false), "fromPrimitiveArray", - input :: Nil) + input :: Nil, + returnNullable = false) } else { NewInstance( classOf[GenericArrayData], @@ -504,49 +507,56 @@ object ScalaReflection extends ScalaReflection { classOf[UTF8String], StringType, "fromString", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, DateType, "fromJavaDate", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.math.BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => StaticInvoke( Decimal.getClass, DecimalType.BigIntDecimal, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => StaticInvoke( Decimal.getClass, DecimalType.BigIntDecimal, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.lang.Integer] => Invoke(inputObject, "intValue", IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 0f8282d3b2f1f196a6d59fe4ce4c6bb32d9eb292..cc32fac67e92481896b1fbaff3d9d94d7685af66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -96,28 +96,32 @@ object RowEncoder { DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case DateType => StaticInvoke( DateTimeUtils.getClass, DateType, "fromJavaDate", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case d: DecimalType => StaticInvoke( Decimal.getClass, d, "fromDecimal", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case StringType => StaticInvoke( classOf[UTF8String], StringType, "fromString", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t @ ArrayType(et, cn) => et match { @@ -126,7 +130,8 @@ object RowEncoder { classOf[ArrayData], t, "toArrayData", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case _ => MapObjects( element => serializerFor(ValidateExternalType(element, et), et), inputObject, @@ -254,14 +259,16 @@ object RowEncoder { DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - input :: Nil) + input :: Nil, + returnNullable = false) case DateType => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - input :: Nil) + input :: Nil, + returnNullable = false) case _: DecimalType => Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), @@ -280,7 +287,8 @@ object RowEncoder { scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", - arrayData :: Nil) + arrayData :: Nil, + returnNullable = false) case MapType(kt, vt, valueNullable) => val keyArrayType = ArrayType(kt, false) @@ -293,7 +301,8 @@ object RowEncoder { ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", - keyData :: valueData :: Nil) + keyData :: valueData :: Nil, + returnNullable = false) case schema @ StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index ce07f4a25c1891f21fb531b5eb819c40027e5e2e..24c06d8b14b548cfab0c697222e43a76cbe860b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -118,17 +118,20 @@ trait InvokeLike extends Expression with NonSQLExpression { * @param arguments An optional list of expressions to pass as arguments to the function. * @param propagateNull When true, and any of the arguments is null, null will be returned instead * of calling the function. + * @param returnNullable When false, indicating the invoked method will always return + * non-null value. */ case class StaticInvoke( staticObject: Class[_], dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends InvokeLike { + propagateNull: Boolean = true, + returnNullable: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") - override def nullable: Boolean = true + override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments override def eval(input: InternalRow): Any = @@ -141,19 +144,40 @@ case class StaticInvoke( val callFunc = s"$objectName.$functionName($argString)" - // If the function can return null, we do an extra check to make sure our null bit is still set - // correctly. - val postNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" + val prepareIsNull = if (nullable) { + s"boolean ${ev.isNull} = $resultIsNull;" } else { + ev.isNull = "false" "" } + val evaluate = if (returnNullable) { + if (ctx.defaultValue(dataType) == "null") { + s""" + ${ev.value} = $callFunc; + ${ev.isNull} = ${ev.value} == null; + """ + } else { + val boxedResult = ctx.freshName("boxedResult") + s""" + ${ctx.boxedType(dataType)} $boxedResult = $callFunc; + ${ev.isNull} = $boxedResult == null; + if (!${ev.isNull}) { + ${ev.value} = $boxedResult; + } + """ + } + } else { + s"${ev.value} = $callFunc;" + } + val code = s""" $argCode - boolean ${ev.isNull} = $resultIsNull; - final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc; - $postNullCheck + $prepareIsNull + $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!$resultIsNull) { + $evaluate + } """ ev.copy(code = code) }