From 46e212d2f062ce4546aca812b55d5b5e2f6563ff Mon Sep 17 00:00:00 2001
From: DB Tsai <dbt@netflix.com>
Date: Wed, 12 Apr 2017 11:19:20 +0800
Subject: [PATCH] [SPARK-20291][SQL] NaNvl(FloatType, NullType) should not be
 cast to NaNvl(DoubleType, DoubleType)

## What changes were proposed in this pull request?

`NaNvl(float value, null)` will be converted into `NaNvl(float value, Cast(null, DoubleType))` and finally `NaNvl(Cast(float value, DoubleType), Cast(null, DoubleType))`.

This will cause mismatching in the output type when the input type is float.

By adding extra rule in TypeCoercion can resolve this issue.

## How was this patch tested?

unite tests.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: DB Tsai <dbt@netflix.com>

Closes #17606 from dbtsai/fixNaNvl.

(cherry picked from commit 8ad63ee158815de5ffff7bf03cdf25aef312095f)
Signed-off-by: DB Tsai <dbtsai@dbtsai.com>
---
 .../spark/sql/catalyst/analysis/TypeCoercion.scala |  1 +
 .../sql/catalyst/analysis/TypeCoercionSuite.scala  | 14 ++++++++++----
 .../apache/spark/sql/DataFrameNaFunctions.scala    |  3 +--
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 6d9799fb70..6700fc79d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -512,6 +512,7 @@ object TypeCoercion {
         NaNvl(l, Cast(r, DoubleType))
       case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
         NaNvl(Cast(l, DoubleType), r)
+      case NaNvl(l, r) if r.dataType == NullType => NaNvl(l, Cast(r, l.dataType))
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 590c9d5e84..bfa52be24b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -560,14 +560,20 @@ class TypeCoercionSuite extends PlanTest {
 
   test("nanvl casts") {
     ruleTest(TypeCoercion.FunctionArgumentConversion,
-      NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),
-      NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
+      NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)),
+      NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
     ruleTest(TypeCoercion.FunctionArgumentConversion,
-      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)),
-      NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType)))
+      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)),
+      NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType)))
     ruleTest(TypeCoercion.FunctionArgumentConversion,
       NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
       NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)),
+      NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType)))
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)),
+      NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType)))
   }
 
   test("type coercion for If") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index f6ab770e87..3fbc39142c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -407,8 +407,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     val quotedColName = "`" + col.name + "`"
     val colValue = col.dataType match {
       case DoubleType | FloatType =>
-        // nanvl only supports these types
-        nanvl(df.col(quotedColName), lit(null).cast(col.dataType))
+        nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
       case _ => df.col(quotedColName)
     }
     coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)
-- 
GitLab