diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index bd5475d2066fc4e0e28030166a895d9672c3e61d..47c5455435ec60797b6dbb47e842a1ba42217d3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -175,8 +175,10 @@ class CodeGenContext {
    * Generate code for compare expression in Java
    */
   def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
+    // java boolean doesn't support > or < operator
+    case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
     // use c1 - c2 may overflow
-    case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
+    case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
     case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
     case other => s"$c1.compare($c2)"
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 4bbbbe6c7f0917123d83c0e6bdfe45680ab08444..6c93698f8017bec90e7156e03eb58f1ac2d596ca 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType}
+import org.apache.spark.sql.types.Decimal
 
 
 class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -123,23 +123,39 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     }
   }
 
-  test("MaxOf") {
-    checkEvaluation(MaxOf(1, 2), 2)
-    checkEvaluation(MaxOf(2, 1), 2)
-    checkEvaluation(MaxOf(1L, 2L), 2L)
-    checkEvaluation(MaxOf(2L, 1L), 2L)
+  test("MaxOf basic") {
+    testNumericDataTypes { convert =>
+      val small = Literal(convert(1))
+      val large = Literal(convert(2))
+      checkEvaluation(MaxOf(small, large), convert(2))
+      checkEvaluation(MaxOf(large, small), convert(2))
+      checkEvaluation(MaxOf(Literal.create(null, small.dataType), large), convert(2))
+      checkEvaluation(MaxOf(large, Literal.create(null, small.dataType)), convert(2))
+    }
+  }
 
-    checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2)
-    checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2)
+  test("MaxOf for atomic type") {
+    checkEvaluation(MaxOf(true, false), true)
+    checkEvaluation(MaxOf("abc", "bcd"), "bcd")
+    checkEvaluation(MaxOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)),
+      Array(1.toByte, 3.toByte))
   }
 
-  test("MinOf") {
-    checkEvaluation(MinOf(1, 2), 1)
-    checkEvaluation(MinOf(2, 1), 1)
-    checkEvaluation(MinOf(1L, 2L), 1L)
-    checkEvaluation(MinOf(2L, 1L), 1L)
+  test("MinOf basic") {
+    testNumericDataTypes { convert =>
+      val small = Literal(convert(1))
+      val large = Literal(convert(2))
+      checkEvaluation(MinOf(small, large), convert(1))
+      checkEvaluation(MinOf(large, small), convert(1))
+      checkEvaluation(MinOf(Literal.create(null, small.dataType), large), convert(2))
+      checkEvaluation(MinOf(small, Literal.create(null, small.dataType)), convert(1))
+    }
+  }
 
-    checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1)
-    checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1)
+  test("MinOf for atomic type") {
+    checkEvaluation(MinOf(true, false), false)
+    checkEvaluation(MinOf("abc", "bcd"), "abc")
+    checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)),
+      Array(1.toByte, 2.toByte))
   }
 }