From 46e212d2f062ce4546aca812b55d5b5e2f6563ff Mon Sep 17 00:00:00 2001 From: DB Tsai <dbt@netflix.com> Date: Wed, 12 Apr 2017 11:19:20 +0800 Subject: [PATCH] [SPARK-20291][SQL] NaNvl(FloatType, NullType) should not be cast to NaNvl(DoubleType, DoubleType) ## What changes were proposed in this pull request? `NaNvl(float value, null)` will be converted into `NaNvl(float value, Cast(null, DoubleType))` and finally `NaNvl(Cast(float value, DoubleType), Cast(null, DoubleType))`. This will cause mismatching in the output type when the input type is float. By adding extra rule in TypeCoercion can resolve this issue. ## How was this patch tested? unite tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: DB Tsai <dbt@netflix.com> Closes #17606 from dbtsai/fixNaNvl. (cherry picked from commit 8ad63ee158815de5ffff7bf03cdf25aef312095f) Signed-off-by: DB Tsai <dbtsai@dbtsai.com> --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 1 + .../sql/catalyst/analysis/TypeCoercionSuite.scala | 14 ++++++++++---- .../apache/spark/sql/DataFrameNaFunctions.scala | 3 +-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 6d9799fb70..6700fc79d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -512,6 +512,7 @@ object TypeCoercion { NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => NaNvl(Cast(l, DoubleType), r) + case NaNvl(l, r) if r.dataType == NullType => NaNvl(l, Cast(r, l.dataType)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 590c9d5e84..bfa52be24b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -560,14 +560,20 @@ class TypeCoercionSuite extends PlanTest { test("nanvl casts") { ruleTest(TypeCoercion.FunctionArgumentConversion, - NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)), - NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType))) + NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), + NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) ruleTest(TypeCoercion.FunctionArgumentConversion, - NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)), - NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType))) + NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), + NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) + ruleTest(TypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), + NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) + ruleTest(TypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), + NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } test("type coercion for If") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index f6ab770e87..3fbc39142c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -407,8 +407,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val quotedColName = "`" + col.name + "`" val colValue = col.dataType match { case DoubleType | FloatType => - // nanvl only supports these types - nanvl(df.col(quotedColName), lit(null).cast(col.dataType)) + nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types case _ => df.col(quotedColName) } coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name) -- GitLab