From ce10545d3401c555e56a214b7c2f334274803660 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN <ueshin@databricks.com> Date: Wed, 5 Jul 2017 11:24:38 +0800 Subject: [PATCH] [SPARK-21300][SQL] ExternalMapToCatalyst should null-check map key prior to converting to internal value. ## What changes were proposed in this pull request? `ExternalMapToCatalyst` should null-check map key prior to converting to internal value to throw an appropriate Exception instead of something like NPE. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN <ueshin@databricks.com> Closes #18524 from ueshin/issues/SPARK-21300. --- .../spark/sql/catalyst/JavaTypeInference.scala | 1 + .../spark/sql/catalyst/ScalaReflection.scala | 1 + .../catalyst/expressions/objects/objects.scala | 16 +++++++++++++++- .../encoders/ExpressionEncoderSuite.scala | 8 +++++++- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 7683ee7074..90ec699877 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -418,6 +418,7 @@ object JavaTypeInference { inputObject, ObjectType(keyType.getRawType), serializerFor(_, keyType), + keyNullable = true, ObjectType(valueType.getRawType), serializerFor(_, valueType), valueNullable = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d580cf4d33..f3c1e41500 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -494,6 +494,7 @@ object ScalaReflection extends ScalaReflection { inputObject, dataTypeFor(keyType), serializerFor(_, keyType, keyPath, seenTypeSet), + keyNullable = !keyType.typeSymbol.asClass.isPrimitive, dataTypeFor(valueType), serializerFor(_, valueType, valuePath, seenTypeSet), valueNullable = !valueType.typeSymbol.asClass.isPrimitive) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4b651836ff..d6d06aecc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -841,18 +841,21 @@ object ExternalMapToCatalyst { inputMap: Expression, keyType: DataType, keyConverter: Expression => Expression, + keyNullable: Boolean, valueType: DataType, valueConverter: Expression => Expression, valueNullable: Boolean): ExternalMapToCatalyst = { val id = curId.getAndIncrement() val keyName = "ExternalMapToCatalyst_key" + id + val keyIsNull = "ExternalMapToCatalyst_key_isNull" + id val valueName = "ExternalMapToCatalyst_value" + id val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id ExternalMapToCatalyst( keyName, + keyIsNull, keyType, - keyConverter(LambdaVariable(keyName, "false", keyType, false)), + keyConverter(LambdaVariable(keyName, keyIsNull, keyType, keyNullable)), valueName, valueIsNull, valueType, @@ -868,6 +871,8 @@ object ExternalMapToCatalyst { * * @param key the name of the map key variable that used when iterate the map, and used as input for * the `keyConverter` + * @param keyIsNull the nullability of the map key variable that used when iterate the map, and + * used as input for the `keyConverter` * @param keyType the data type of the map key variable that used when iterate the map, and used as * input for the `keyConverter` * @param keyConverter A function that take the `key` as input, and converts it to catalyst format. @@ -883,6 +888,7 @@ object ExternalMapToCatalyst { */ case class ExternalMapToCatalyst private( key: String, + keyIsNull: String, keyType: DataType, keyConverter: Expression, value: String, @@ -913,6 +919,7 @@ case class ExternalMapToCatalyst private( val keyElementJavaType = ctx.javaType(keyType) val valueElementJavaType = ctx.javaType(valueType) + ctx.addMutableState("boolean", keyIsNull, "") ctx.addMutableState(keyElementJavaType, key, "") ctx.addMutableState("boolean", valueIsNull, "") ctx.addMutableState(valueElementJavaType, value, "") @@ -950,6 +957,12 @@ case class ExternalMapToCatalyst private( defineEntries -> defineKeyValue } + val keyNullCheck = if (ctx.isPrimitiveType(keyType)) { + s"$keyIsNull = false;" + } else { + s"$keyIsNull = $key == null;" + } + val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { s"$valueIsNull = false;" } else { @@ -972,6 +985,7 @@ case class ExternalMapToCatalyst private( $defineEntries while($entries.hasNext()) { $defineKeyValue + $keyNullCheck $valueNullCheck ${genKeyConverter.code} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 080f11b769..bb1955a1ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -355,12 +355,18 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { checkNullable[String](true) } - test("null check for map key") { + test("null check for map key: String") { val encoder = ExpressionEncoder[Map[String, Int]]() val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2)))) assert(e.getMessage.contains("Cannot use null as map key")) } + test("null check for map key: Integer") { + val encoder = ExpressionEncoder[Map[Integer, String]]() + val e = intercept[RuntimeException](encoder.toRow(Map((1, "a"), (null, "b")))) + assert(e.getMessage.contains("Cannot use null as map key")) + } + private def encodeDecodeTest[T : ExpressionEncoder]( input: T, testName: String): Unit = { -- GitLab