From f6df609dcc4f4a18c0f1c74b1ae0800cf09fa7ae Mon Sep 17 00:00:00 2001
From: Daoyuan Wang <daoyuan.wang@intel.com>
Date: Tue, 2 Dec 2014 14:21:12 -0800
Subject: [PATCH] [SPARK-4593][SQL] Return null when denominator is 0

SELECT max(1/0) FROM src
would return a very large number, which is obviously not right.
hive-0.12 returns `Infinity` for 1/0, while hive-0.13.1 returns `NULL`.
I think it is better to keep our behavior consistent with the newer Hive version.
This PR ensures that when the divisor is 0, the result of the expression is NULL, matching hive-0.13.1.
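
A minimal illustration of the intended behavior (assuming a table `src` exists, as in the example above):

    SELECT 1/0 FROM src;       -- NULL after this patch (hive-0.12 returned Infinity)
    SELECT max(1/0) FROM src;  -- NULL instead of a very large number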

Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #3443 from adrian-wang/div and squashes the following commits:

2e98677 [Daoyuan Wang] fix code gen for divide 0
85c28ba [Daoyuan Wang] temp
36236a5 [Daoyuan Wang] add test cases
6f5716f [Daoyuan Wang] fix comments
cee92bd [Daoyuan Wang] avoid evaluation 2 times
22ecd9a [Daoyuan Wang] fix style
cf28c58 [Daoyuan Wang] divide fix
2dfe50f [Daoyuan Wang] return null when divider is 0 of Double type
---
 .../sql/catalyst/expressions/Expression.scala | 41 +++++++++++++++++++
 .../sql/catalyst/expressions/arithmetic.scala | 13 ++++--
 .../expressions/codegen/CodeGenerator.scala   | 19 ++++++++-
 .../ExpressionEvaluationSuite.scala           | 15 +++++++
 4 files changed, 83 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 39b120e8de..bc45881e42 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -153,6 +153,25 @@ abstract class Expression extends TreeNode[Expression] {
     }
   }
 
+  /**
+   * Evaluation helper function for a single Fractional child expression.
+   * If the child expression evaluates to null, the evaluation result is null.
+   */
+  @inline
+  protected final def f1(i: Row, e1: Expression, f: ((Fractional[Any], Any) => Any)): Any = {
+    val evalE1 = e1.eval(i)
+    if (evalE1 == null) {
+      null
+    } else {
+      e1.dataType match {
+        case ft: FractionalType =>
+          f.asInstanceOf[(Fractional[ft.JvmType], ft.JvmType) => ft.JvmType](
+            ft.fractional, evalE1.asInstanceOf[ft.JvmType])
+        case other => sys.error(s"Type $other does not support fractional operations")
+      }
+    }
+  }
+
   /**
    * Evaluation helper function for 2 Integral children expressions. Those expressions are
    * supposed to be in the same data type, and also the return type.
@@ -189,6 +208,28 @@ abstract class Expression extends TreeNode[Expression] {
     }
   }
 
+  /**
+   * Evaluation helper function for a single Integral child expression.
+   * If the child expression evaluates to null, the evaluation result is null.
+   */
+  @inline
+  protected final def i1(i: Row, e1: Expression, f: ((Integral[Any], Any) => Any)): Any = {
+    val evalE1 = e1.eval(i)
+    if (evalE1 == null) {
+      null
+    } else {
+      e1.dataType match {
+        case it: IntegralType =>
+          f.asInstanceOf[(Integral[it.JvmType], it.JvmType) => it.JvmType](
+            it.integral, evalE1.asInstanceOf[it.JvmType])
+        case ft: FractionalType =>
+          f.asInstanceOf[(Integral[ft.JvmType], ft.JvmType) => ft.JvmType](
+            ft.asIntegral, evalE1.asInstanceOf[ft.JvmType])
+        case other => sys.error(s"Type $other does not support numeric operations")
+      }
+    }
+  }
+
   /**
    * Evaluation helper function for 2 Comparable children expressions. Those expressions are
    * supposed to be in the same data type, and the return type should be Integer:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 900b7586ad..7ec18b8419 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -105,11 +105,16 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
 case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
   def symbol = "/"
 
-  override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType]
+  override def nullable = true
 
-  override def eval(input: Row): Any = dataType match {
-    case _: FractionalType => f2(input, left, right, _.div(_, _))
-    case _: IntegralType => i2(input, left , right, _.quot(_, _))
+  override def eval(input: Row): Any = {
+    val evalE2 = right.eval(input)
+    dataType match {
+      case _ if evalE2 == null => null
+      case _ if evalE2 == 0 => null
+      case ft: FractionalType => f1(input, left, _.div(_, evalE2.asInstanceOf[ft.JvmType]))
+      case it: IntegralType => i1(input, left, _.quot(_, evalE2.asInstanceOf[it.JvmType]))
+    }
   }
 
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 67f8d411b6..ab71e15e1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -359,7 +359,24 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
       case Add(e1, e2) =>      (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" }
       case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" }
       case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" }
-      case Divide(e1, e2) =>   (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 / $eval2" }
+      case Divide(e1, e2) =>
+        val eval1 = expressionEvaluator(e1)
+        val eval2 = expressionEvaluator(e2)
+
+        eval1.code ++ eval2.code ++
+        q"""
+          var $nullTerm = false
+          var $primitiveTerm: ${termForType(e1.dataType)} = 0
+
+          if (${eval1.nullTerm} || ${eval2.nullTerm}) {
+            $nullTerm = true
+          } else if (${eval2.primitiveTerm} == 0) {
+            $nullTerm = true
+          } else {
+            $nullTerm = false
+            $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm}
+          }
+         """.children
 
       case IsNotNull(e) =>
         val eval = expressionEvaluator(e)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 3f5b9f698f..25f5642488 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -149,6 +149,21 @@ class ExpressionEvaluationSuite extends FunSuite {
     checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true)
   }
 
+  test("Divide") {
+    checkEvaluation(Divide(Literal(2), Literal(1)), 2)
+    checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
+    checkEvaluation(Divide(Literal(1), Literal(2)), 0)
+    checkEvaluation(Divide(Literal(1), Literal(0)), null)
+    checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null)
+    checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null)
+    checkEvaluation(Divide(Literal(0), Literal(null, IntegerType)), null)
+    checkEvaluation(Divide(Literal(1), Literal(null, IntegerType)), null)
+    checkEvaluation(Divide(Literal(null, IntegerType), Literal(0)), null)
+    checkEvaluation(Divide(Literal(null, DoubleType), Literal(0.0)), null)
+    checkEvaluation(Divide(Literal(null, IntegerType), Literal(1)), null)
+    checkEvaluation(Divide(Literal(null, IntegerType), Literal(null, IntegerType)), null)
+  }
+
   test("INSET") {
     val hS = HashSet[Any]() + 1 + 2
     val nS = HashSet[Any]() + 1 + 2 + null
-- 
GitLab