Skip to content
Snippets Groups Projects
Commit d202ad2f authored by Liang-Chi Hsieh's avatar Liang-Chi Hsieh Committed by Michael Armbrust
Browse files

[SPARK-12439][SQL] Fix toCatalystArray and MapObjects

JIRA: https://issues.apache.org/jira/browse/SPARK-12439

In toCatalystArray, we should look at the data type returned by dataTypeFor, instead of the one returned by silentSchemaFor, to determine whether the element is a native type. An obvious problem arises when the element class is Option[Int]: silentSchemaFor will return Int, so we will wrongly treat the element as a native type.

There is another problem when using Option as an array element. When we encode data like Seq(Some(1), Some(2), None) with an encoder, we later use MapObjects to construct an array for it. But in MapObjects, we don't check whether the return value of lambdaFunction is null. This causes a bug where the decoded data for Seq(Some(1), Some(2), None) is Seq(1, 2, -1) instead of Seq(1, 2, null).

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #10391 from viirya/fix-catalystarray.
parent 8ce645d4
No related branches found
No related tags found
No related merge requests found
...@@ -405,7 +405,7 @@ object ScalaReflection extends ScalaReflection { ...@@ -405,7 +405,7 @@ object ScalaReflection extends ScalaReflection {
def toCatalystArray(input: Expression, elementType: `Type`): Expression = { def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
val externalDataType = dataTypeFor(elementType) val externalDataType = dataTypeFor(elementType)
val Schema(catalystType, nullable) = silentSchemaFor(elementType) val Schema(catalystType, nullable) = silentSchemaFor(elementType)
if (isNativeType(catalystType)) { if (isNativeType(externalDataType)) {
NewInstance( NewInstance(
classOf[GenericArrayData], classOf[GenericArrayData],
input :: Nil, input :: Nil,
......
...@@ -35,7 +35,8 @@ object RowEncoder { ...@@ -35,7 +35,8 @@ object RowEncoder {
def apply(schema: StructType): ExpressionEncoder[Row] = { def apply(schema: StructType): ExpressionEncoder[Row] = {
val cls = classOf[Row] val cls = classOf[Row]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true) val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
val extractExpressions = extractorsFor(inputObject, schema) // We use an If expression to wrap extractorsFor result of StructType
val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue
val constructExpression = constructorFor(schema) val constructExpression = constructorFor(schema)
new ExpressionEncoder[Row]( new ExpressionEncoder[Row](
schema, schema,
...@@ -129,7 +130,9 @@ object RowEncoder { ...@@ -129,7 +130,9 @@ object RowEncoder {
Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil), Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
f.dataType)) f.dataType))
} }
CreateStruct(convertedFields) If(IsNull(inputObject),
Literal.create(null, inputType),
CreateStruct(convertedFields))
} }
private def externalDataTypeFor(dt: DataType): DataType = dt match { private def externalDataTypeFor(dt: DataType): DataType = dt match {
...@@ -220,6 +223,8 @@ object RowEncoder { ...@@ -220,6 +223,8 @@ object RowEncoder {
Literal.create(null, externalDataTypeFor(f.dataType)), Literal.create(null, externalDataTypeFor(f.dataType)),
constructorFor(GetStructField(input, i))) constructorFor(GetStructField(input, i)))
} }
CreateExternalRow(convertedFields) If(IsNull(input),
Literal.create(null, externalDataTypeFor(input.dataType)),
CreateExternalRow(convertedFields))
} }
} }
...@@ -456,10 +456,10 @@ case class MapObjects( ...@@ -456,10 +456,10 @@ case class MapObjects(
($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
$loopNullCheck $loopNullCheck
if (${loopVar.isNull}) { ${genFunction.code}
if (${genFunction.isNull}) {
$convertedArray[$loopIndex] = null; $convertedArray[$loopIndex] = null;
} else { } else {
${genFunction.code}
$convertedArray[$loopIndex] = ${genFunction.value}; $convertedArray[$loopIndex] = ${genFunction.value};
} }
......
...@@ -160,6 +160,9 @@ class ExpressionEncoderSuite extends SparkFunSuite { ...@@ -160,6 +160,9 @@ class ExpressionEncoderSuite extends SparkFunSuite {
productTest(OptionalData(None, None, None, None, None, None, None, None)) productTest(OptionalData(None, None, None, None, None, None, None, None))
encodeDecodeTest(Seq(Some(1), None), "Option in array")
encodeDecodeTest(Map(1 -> Some(10L), 2 -> Some(20L), 3 -> None), "Option in map")
productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
productTest(BoxedData(null, null, null, null, null, null, null)) productTest(BoxedData(null, null, null, null, null, null, null))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment