From 09fcf96b8f881988a4bc7fe26a3f6ed12dfb6adb Mon Sep 17 00:00:00 2001
From: Wenchen Fan <cloud0fan@outlook.com>
Date: Tue, 23 Jun 2015 23:11:42 -0700
Subject: [PATCH] [SPARK-8371] [SQL] improve unit test for MaxOf and MinOf and
 fix bugs

a follow up of https://github.com/apache/spark/pull/6813

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #6825 from cloud-fan/cg and squashes the following commits:

43170cc [Wenchen Fan] fix bugs in code gen
---
 .../expressions/codegen/CodeGenerator.scala   |  4 +-
 .../ArithmeticExpressionSuite.scala           | 46 +++++++++++++------
 2 files changed, 34 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index bd5475d206..47c5455435 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -175,8 +175,10 @@ class CodeGenContext {
    * Generate code for compare expression in Java
    */
   def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
+    // java boolean doesn't support > or < operator
+    case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
     // use c1 - c2 may overflow
-    case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
+    case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
     case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
     case other => s"$c1.compare($c2)"
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 4bbbbe6c7f..6c93698f80 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType}
+import org.apache.spark.sql.types.Decimal
 
 
 class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -123,23 +123,39 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     }
   }
 
-  test("MaxOf") {
-    checkEvaluation(MaxOf(1, 2), 2)
-    checkEvaluation(MaxOf(2, 1), 2)
-    checkEvaluation(MaxOf(1L, 2L), 2L)
-    checkEvaluation(MaxOf(2L, 1L), 2L)
+  test("MaxOf basic") {
+    testNumericDataTypes { convert =>
+      val small = Literal(convert(1))
+      val large = Literal(convert(2))
+      checkEvaluation(MaxOf(small, large), convert(2))
+      checkEvaluation(MaxOf(large, small), convert(2))
+      checkEvaluation(MaxOf(Literal.create(null, small.dataType), large), convert(2))
+      checkEvaluation(MaxOf(large, Literal.create(null, small.dataType)), convert(2))
+    }
+  }
 
-    checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2)
-    checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2)
+  test("MaxOf for atomic type") {
+    checkEvaluation(MaxOf(true, false), true)
+    checkEvaluation(MaxOf("abc", "bcd"), "bcd")
+    checkEvaluation(MaxOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)),
+      Array(1.toByte, 3.toByte))
   }
 
-  test("MinOf") {
-    checkEvaluation(MinOf(1, 2), 1)
-    checkEvaluation(MinOf(2, 1), 1)
-    checkEvaluation(MinOf(1L, 2L), 1L)
-    checkEvaluation(MinOf(2L, 1L), 1L)
+  test("MinOf basic") {
+    testNumericDataTypes { convert =>
+      val small = Literal(convert(1))
+      val large = Literal(convert(2))
+      checkEvaluation(MinOf(small, large), convert(1))
+      checkEvaluation(MinOf(large, small), convert(1))
+      checkEvaluation(MinOf(Literal.create(null, small.dataType), large), convert(2))
+      checkEvaluation(MinOf(small, Literal.create(null, small.dataType)), convert(1))
+    }
+  }
 
-    checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1)
-    checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1)
+  test("MinOf for atomic type") {
+    checkEvaluation(MinOf(true, false), false)
+    checkEvaluation(MinOf("abc", "bcd"), "abc")
+    checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)),
+      Array(1.toByte, 2.toByte))
   }
 }
-- 
GitLab