From 4ec931951fea4efbfe5db39cf581704df7d2775b Mon Sep 17 00:00:00 2001
From: Cheng Hao <hao.cheng@intel.com>
Date: Wed, 8 Oct 2014 17:52:27 -0700
Subject: [PATCH] [SPARK-3707] [SQL] Fix bug of type coercion in DIV

Calling `BinaryArithmetic.dataType` will throws exception until it's resolved, but in type coercion rule `Division`, seems doesn't follow this.

Author: Cheng Hao <hao.cheng@intel.com>

Closes #2559 from chenghao-intel/type_coercion and squashes the following commits:

199a85d [Cheng Hao] Simplify the divide rule
dc55218 [Cheng Hao] fix bug of type coercion in div
---
 .../catalyst/analysis/HiveTypeCoercion.scala  |  7 +++-
 .../sql/catalyst/analysis/AnalysisSuite.scala | 40 +++++++++++++++++--
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 79e5283e86..64881854df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -348,8 +348,11 @@ trait HiveTypeCoercion {
       case e if !e.childrenResolved => e
 
       // Decimal and Double remain the same
-      case d: Divide if d.dataType == DoubleType => d
-      case d: Divide if d.dataType == DecimalType => d
+      case d: Divide if d.resolved && d.dataType == DoubleType => d
+      case d: Divide if d.resolved && d.dataType == DecimalType => d
+
+      case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType))
+      case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r)
 
       case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 5809a108ff..7b45738c4f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.scalatest.{BeforeAndAfter, FunSuite}
 
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.types.IntegerType
+import org.apache.spark.sql.catalyst.types._
 
 class AnalysisSuite extends FunSuite with BeforeAndAfter {
   val caseSensitiveCatalog = new SimpleCatalog(true)
@@ -33,6 +34,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false)
 
   val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
+  val testRelation2 = LocalRelation(
+    AttributeReference("a", StringType)(),
+    AttributeReference("b", StringType)(),
+    AttributeReference("c", DoubleType)(),
+    AttributeReference("d", DecimalType)(),
+    AttributeReference("e", ShortType)())
 
   before {
     caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation)
@@ -74,7 +81,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     val e = intercept[RuntimeException] {
       caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None))
     }
-    assert(e.getMessage === "Table Not Found: tAbLe")
+    assert(e.getMessage == "Table Not Found: tAbLe")
 
     assert(
       caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
@@ -106,4 +113,31 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     }
     assert(e.getMessage().toLowerCase.contains("unresolved plan"))
   }
+
+  test("divide should be casted into fractional types") {
+    val testRelation2 = LocalRelation(
+      AttributeReference("a", StringType)(),
+      AttributeReference("b", StringType)(),
+      AttributeReference("c", DoubleType)(),
+      AttributeReference("d", DecimalType)(),
+      AttributeReference("e", ShortType)())
+
+    val expr0 = 'a / 2
+    val expr1 = 'a / 'b
+    val expr2 = 'a / 'c
+    val expr3 = 'a / 'd
+    val expr4 = 'e / 'e
+    val plan = caseInsensitiveAnalyze(Project(
+      Alias(expr0, s"Analyzer($expr0)")() ::
+      Alias(expr1, s"Analyzer($expr1)")() ::
+      Alias(expr2, s"Analyzer($expr2)")() ::
+      Alias(expr3, s"Analyzer($expr3)")() ::
+      Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2))
+    val pl = plan.asInstanceOf[Project].projectList
+    assert(pl(0).dataType == DoubleType)
+    assert(pl(1).dataType == DoubleType)
+    assert(pl(2).dataType == DoubleType)
+    assert(pl(3).dataType == DecimalType)
+    assert(pl(4).dataType == DoubleType)
+  }
 }
-- 
GitLab