From 30e89111d673776a6b59b11cdb29ab8713ba6f7c Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Sun, 2 Aug 2015 20:12:03 -0700
Subject: [PATCH] [SPARK-9546][SQL] Centralize orderable data type checking.

This pull request creates two isOrderable functions in RowOrdering that check whether a data type, or the output of a sequence of expressions, can be used in sorting (a usage sketch follows the --- separator below).

Author: Reynold Xin <rxin@databricks.com>

Closes #7880 from rxin/SPARK-9546 and squashes the following commits:

f9e322d [Reynold Xin] Fixed tests.
0439b43 [Reynold Xin] [SPARK-9546][SQL] Centralize orderable data type checking.
---
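Usage sketch for reviewers (not part of the commit): the data-type overload
accepts atomic types, NullType, and structs whose fields are all recursively
orderable; the Seq[Expression] overload checks every expression's output type,
mirroring the new sort-merge-join guard in SparkStrategies. The attribute
name below is hypothetical.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._

    RowOrdering.isOrderable(IntegerType)                    // true
    RowOrdering.isOrderable(ArrayType(IntegerType))         // false: arrays and maps are not orderable
    RowOrdering.isOrderable(
      StructType(StructField("a", LongType) :: Nil))        // true: every field is orderable

    // Seq[Expression] overload: true iff every output type is orderable.
    val key = AttributeReference("key", StringType)()
    RowOrdering.isOrderable(Seq(key))                       // true
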
 .../sql/catalyst/analysis/CheckAnalysis.scala |  8 +-
 .../expressions/ExpectsInputTypes.scala       |  4 +-
 .../sql/catalyst/expressions/Expression.scala |  2 +-
 .../catalyst/expressions/RowOrdering.scala    | 93 +++++++++++++++++++
 .../sql/catalyst/expressions/SortOrder.scala  |  9 ++
 .../expressions/codegen/CodeGenerator.scala   | 12 ++-
 .../codegen/GenerateOrdering.scala            |  2 -
 .../expressions/collectionOperations.scala    | 21 +++--
 .../spark/sql/catalyst/expressions/rows.scala | 44 ---------
 .../spark/sql/catalyst/util/TypeUtils.scala   | 27 +++---
 .../apache/spark/sql/types/StructType.scala   | 12 ---
 .../analysis/AnalysisErrorSuite.scala         | 14 +--
 .../ExpressionTypeCheckingSuite.scala         | 50 +++++-----
 .../spark/sql/execution/SparkStrategies.scala | 14 +--
 .../spark/sql/DataFrameFunctionsSuite.scala   |  5 +-
 15 files changed, 173 insertions(+), 144 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 364569d8f0..187b238045 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -130,11 +130,9 @@ trait CheckAnalysis {
 
           case Sort(orders, _, _) =>
             orders.foreach { order =>
-              order.dataType match {
-                case t: AtomicType => // OK
-                case NullType => // OK
-                case t =>
-                  failAnalysis(s"Sorting is not supported for columns of type ${t.simpleString}")
+              if (!RowOrdering.isOrderable(order.dataType)) {
+                failAnalysis(
+                  s"sorting is not supported for columns of type ${order.dataType.simpleString}")
               }
             }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
index abe6457747..2dcbd4eb15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
@@ -44,8 +44,8 @@ trait ExpectsInputTypes extends Expression {
   override def checkInputDataTypes(): TypeCheckResult = {
     val mismatches = children.zip(inputTypes).zipWithIndex.collect {
       case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
-        s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " +
-          s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}."
+        s"argument ${idx + 1} requires ${expected.simpleString} type, " +
+          s"however, '${child.prettyString}' is of ${child.dataType.simpleString} type."
     }
 
     if (mismatches.isEmpty) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 2842b3ec5a..ef2fc2e8c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -420,7 +420,7 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
       TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
         s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
     } else if (!inputType.acceptsType(left.dataType)) {
-      TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," +
+      TypeCheckResult.TypeCheckFailure(s"'$prettyString' requires ${inputType.simpleString} type," +
         s" not ${left.dataType.simpleString}")
     } else {
       TypeCheckResult.TypeCheckSuccess
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala
new file mode 100644
index 0000000000..873f5324c5
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+
+
+/**
+ * An interpreted row ordering comparator.
+ */
+class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
+
+  def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
+    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+
+  def compare(a: InternalRow, b: InternalRow): Int = {
+    var i = 0
+    while (i < ordering.size) {
+      val order = ordering(i)
+      val left = order.child.eval(a)
+      val right = order.child.eval(b)
+
+      if (left == null && right == null) {
+        // Both null, continue looking.
+      } else if (left == null) {
+        return if (order.direction == Ascending) -1 else 1
+      } else if (right == null) {
+        return if (order.direction == Ascending) 1 else -1
+      } else {
+        val comparison = order.dataType match {
+          case dt: AtomicType if order.direction == Ascending =>
+            dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
+          case dt: AtomicType if order.direction == Descending =>
+            dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+          case s: StructType if order.direction == Ascending =>
+            s.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
+          case s: StructType if order.direction == Descending =>
+            s.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+          case other =>
+            throw new IllegalArgumentException(s"Type $other does not support ordered operations")
+        }
+        if (comparison != 0) {
+          return comparison
+        }
+      }
+      i += 1
+    }
+    return 0
+  }
+}
+
+object RowOrdering {
+
+  /**
+   * Returns true iff the data type can be ordered (i.e. can be sorted).
+   */
+  def isOrderable(dataType: DataType): Boolean = dataType match {
+    case NullType => true
+    case dt: AtomicType => true
+    case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
+    case _ => false
+  }
+
+  /**
+   * Returns true iff outputs from the expressions can be ordered.
+   */
+  def isOrderable(exprs: Seq[Expression]): Boolean = exprs.forall(e => isOrderable(e.dataType))
+
+  /**
+   * Creates a [[RowOrdering]] for the given schema, in natural ascending order.
+   */
+  def forSchema(dataTypes: Seq[DataType]): RowOrdering = {
+    new RowOrdering(dataTypes.zipWithIndex.map {
+      case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+    })
+  }
+}
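
A quick sketch of the interpreted comparator above, assuming
InternalRow.fromSeq is available for building test rows (used here purely
for illustration):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.RowOrdering
    import org.apache.spark.sql.types._

    // Natural ascending order over (int, long) rows via forSchema.
    val ordering = RowOrdering.forSchema(Seq(IntegerType, LongType))
    val a = InternalRow.fromSeq(Seq(1, 10L))
    val b = InternalRow.fromSeq(Seq(1, 20L))
    ordering.compare(a, b)   // < 0: field 0 ties, field 1 decides
    ordering.compare(b, a)   // > 0; nulls sort first under Ascending
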
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 5eb5b0d176..f6a872ba44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator
@@ -36,6 +37,14 @@ case class SortOrder(child: Expression, direction: SortDirection)
   /** Sort order is not foldable because we don't have an eval for it. */
   override def foldable: Boolean = false
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (RowOrdering.isOrderable(dataType)) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.simpleString}")
+    }
+  }
+
   override def dataType: DataType = child.dataType
   override def nullable: Boolean = child.nullable
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 3c91227d06..03ec4b4b4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -220,7 +220,11 @@ class CodeGenContext {
   }
 
   /**
-   * Generates code for compare expression in Java.
+   * Generates code for comparing two expressions.
+   *
+   * @param dataType data type of the expressions
+   * @param c1 name of the variable of expression 1's output
+   * @param c2 name of the variable of expression 2's output
    */
   def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
     // java boolean doesn't support > or < operator
@@ -231,7 +235,7 @@ class CodeGenContext {
     case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
     case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
     case NullType => "0"
-    case schema: StructType if schema.supportOrdering(schema) =>
+    case schema: StructType =>
       val comparisons = GenerateOrdering.genComparisons(this, schema)
       val compareFunc = freshName("compareStruct")
       val funcCode: String =
@@ -245,8 +249,8 @@ class CodeGenContext {
       addNewFunction(compareFunc, funcCode)
       s"this.$compareFunc($c1, $c2)"
     case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
-    case _ => throw new IllegalArgumentException(
-      "cannot generate compare code for un-comparable type")
+    case _ =>
+      throw new IllegalArgumentException("cannot generate compare code for un-comparable type")
   }
 
   /**
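
To make genComp's contract concrete, a hedged sketch of the Java source it
returns (constructing a bare CodeGenContext like this is assumed to work for
illustration; the variable names are caller-chosen):

    import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenContext
    import org.apache.spark.sql.types._

    val ctx = new CodeGenContext
    // Primitive types use the inline ternary template above:
    ctx.genComp(IntegerType, "value1", "value2")
    // => "(value1 > value2 ? 1 : value1 < value2 ? -1 : 0)"
    // A StructType argument instead registers a fresh compareStruct
    // helper via addNewFunction and returns a call to it.
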
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 4da91ed8d7..42be394c3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.Private
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.StructType
@@ -26,7 +25,6 @@ import org.apache.spark.sql.types.StructType
 /**
  * Inherits some default implementation for Java from `Ordering[Row]`
  */
-@Private
 class BaseOrdering extends Ordering[InternalRow] {
   def compare(a: InternalRow, b: InternalRow): Int = {
     throw new UnsupportedOperationException
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 80b8da23e8..6ccb56578f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -20,6 +20,7 @@ import java.util.Comparator
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 /**
@@ -54,15 +55,17 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)
 
   override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
-    case _ @ ArrayType(n: AtomicType, _) => TypeCheckResult.TypeCheckSuccess
-    case _ @ ArrayType(n, _) => TypeCheckResult.TypeCheckFailure(
-                    s"Type $n is not the AtomicType, we can not perform the ordering operations")
-    case other =>
-      TypeCheckResult.TypeCheckFailure(s"ArrayType(AtomicType) is expected, but we got $other")
+    case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
+      TypeCheckResult.TypeCheckSuccess
+    case ArrayType(dt, _) =>
+      TypeCheckResult.TypeCheckFailure(
+        s"$prettyName does not support sorting array of type ${dt.simpleString}")
+    case _ =>
+      TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
   }
 
   @transient
-  private lazy val lt = {
+  private lazy val lt: Comparator[Any] = {
     val ordering = base.dataType match {
       case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
     }
@@ -83,7 +86,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
   }
 
   @transient
-  private lazy val gt = {
+  private lazy val gt: Comparator[Any] = {
     val ordering = base.dataType match {
       case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
     }
@@ -106,9 +109,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
   override def nullSafeEval(array: Any, ascending: Any): Any = {
     val elementType = base.dataType.asInstanceOf[ArrayType].elementType
     val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
-    java.util.Arrays.sort(
-      data,
-      if (ascending.asInstanceOf[Boolean]) lt else gt)
+    java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt)
     new GenericArrayData(data.asInstanceOf[Array[Any]])
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 7e1031c755..d04434b953 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -121,47 +121,3 @@ class GenericMutableRow(val values: Array[Any]) extends MutableRow {
 
   override def copy(): InternalRow = new GenericInternalRow(values.clone())
 }
-
-class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
-  def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
-
-  def compare(a: InternalRow, b: InternalRow): Int = {
-    var i = 0
-    while (i < ordering.size) {
-      val order = ordering(i)
-      val left = order.child.eval(a)
-      val right = order.child.eval(b)
-
-      if (left == null && right == null) {
-        // Both null, continue looking.
-      } else if (left == null) {
-        return if (order.direction == Ascending) -1 else 1
-      } else if (right == null) {
-        return if (order.direction == Ascending) 1 else -1
-      } else {
-        val comparison = order.dataType match {
-          case n: AtomicType if order.direction == Ascending =>
-            n.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
-          case n: AtomicType if order.direction == Descending =>
-            n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
-          case s: StructType if order.direction == Ascending =>
-            s.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
-          case s: StructType if order.direction == Descending =>
-            s.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
-          case other => sys.error(s"Type $other does not support ordered operations")
-        }
-        if (comparison != 0) return comparison
-      }
-      i += 1
-    }
-    return 0
-  }
-}
-
-object RowOrdering {
-  def forSchema(dataTypes: Seq[DataType]): RowOrdering =
-    new RowOrdering(dataTypes.zipWithIndex.map {
-      case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
-    })
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 2f50d40fe2..0b41f92c61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -18,39 +18,34 @@
 package org.apache.spark.sql.catalyst.util
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.RowOrdering
 import org.apache.spark.sql.types._
 
 /**
  * Helper functions to check for valid data types.
  */
 object TypeUtils {
-  def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = {
-    if (t.isInstanceOf[NumericType] || t == NullType) {
+  def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = {
+    if (dt.isInstanceOf[NumericType] || dt == NullType) {
       TypeCheckResult.TypeCheckSuccess
     } else {
-      TypeCheckResult.TypeCheckFailure(s"$caller accepts numeric types, not $t")
+      TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt")
     }
   }
 
-  def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = {
-    t match {
-      case i: AtomicType => TypeCheckResult.TypeCheckSuccess
-      case n: NullType => TypeCheckResult.TypeCheckSuccess
-      case s: StructType =>
-        if (s.supportOrdering(s)) {
-          TypeCheckResult.TypeCheckSuccess
-        } else {
-          TypeCheckResult.TypeCheckFailure(s"Fields in $s do not support ordering")
-        }
-      case other => TypeCheckResult.TypeCheckFailure(s"$t doesn't support ordering on $caller")
+  def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
+    if (RowOrdering.isOrderable(dt)) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt")
     }
-
   }
 
   def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
     if (types.distinct.size > 1) {
       TypeCheckResult.TypeCheckFailure(
-        s"input to $caller should all be the same type, but it's ${types.mkString("[", ", ", "]")}")
+        s"input to $caller should all be the same type, but it's " +
+          types.map(_.simpleString).mkString("[", ", ", "]"))
     } else {
       TypeCheckResult.TypeCheckSuccess
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 2f23144858..6928707f7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -302,18 +302,6 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
   }
 
   private[sql] val ordering = RowOrdering.forSchema(this.fields.map(_.dataType))
-
-  private[sql] def supportOrdering(s: StructType): Boolean = {
-    s.fields.forall { f =>
-      if (f.dataType.isInstanceOf[AtomicType]) {
-        true
-      } else if (f.dataType.isInstanceOf[StructType]) {
-        supportOrdering(f.dataType.asInstanceOf[StructType])
-      } else {
-        false
-      }
-    }
-  }
 }
 
 object StructType extends AbstractDataType {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index aa19cdce31..26935c6e3b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -68,22 +68,22 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
   errorTest(
     "single invalid type, single arg",
     testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)),
-    "cannot resolve" :: "testfunction" :: "argument 1" :: "expected to be of type int" ::
-    "'null' is of type date" ::Nil)
+    "cannot resolve" :: "testfunction" :: "argument 1" :: "requires int type" ::
+    "'null' is of date type" ::Nil)
 
   errorTest(
     "single invalid type, second arg",
     testRelation.select(
       TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)),
-    "cannot resolve" :: "testfunction" :: "argument 2" :: "expected to be of type int" ::
-    "'null' is of type date" ::Nil)
+    "cannot resolve" :: "testfunction" :: "argument 2" :: "requires int type" ::
+    "'null' is of date type" ::Nil)
 
   errorTest(
     "multiple invalid type",
     testRelation.select(
       TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)),
     "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" ::
-    "expected to be of type int" :: "'null' is of type date" ::Nil)
+    "requires int type" :: "'null' is of date type" ::Nil)
 
   errorTest(
     "unresolved window function",
@@ -111,12 +111,12 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
   errorTest(
     "bad casts",
     testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
-    "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
+  "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
 
   errorTest(
     "sorting by unsupported column types",
     listRelation.orderBy('list.asc),
-    "sorting" :: "type" :: "array<int>" :: Nil)
+    "sort" :: "type" :: "array<int>" :: Nil)
 
   errorTest(
     "non-boolean filters",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 8f616ae9d2..c9bcc68f02 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -53,9 +53,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
   }
 
   test("check types for unary arithmetic") {
-    assertError(UnaryMinus('stringField), "type (numeric or calendarinterval)")
-    assertError(Abs('stringField), "expected to be of type numeric")
-    assertError(BitwiseNot('stringField), "expected to be of type integral")
+    assertError(UnaryMinus('stringField), "(numeric or calendarinterval) type")
+    assertError(Abs('stringField), "requires numeric type")
+    assertError(BitwiseNot('stringField), "requires integral type")
   }
 
   test("check types for binary arithmetic") {
@@ -78,21 +78,21 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
     assertErrorForDifferingTypes(MinOf('intField, 'booleanField))
 
-    assertError(Add('booleanField, 'booleanField), "accepts (numeric or calendarinterval) type")
+    assertError(Add('booleanField, 'booleanField), "requires (numeric or calendarinterval) type")
     assertError(Subtract('booleanField, 'booleanField),
-      "accepts (numeric or calendarinterval) type")
-    assertError(Multiply('booleanField, 'booleanField), "accepts numeric type")
-    assertError(Divide('booleanField, 'booleanField), "accepts numeric type")
-    assertError(Remainder('booleanField, 'booleanField), "accepts numeric type")
+      "requires (numeric or calendarinterval) type")
+    assertError(Multiply('booleanField, 'booleanField), "requires numeric type")
+    assertError(Divide('booleanField, 'booleanField), "requires numeric type")
+    assertError(Remainder('booleanField, 'booleanField), "requires numeric type")
 
-    assertError(BitwiseAnd('booleanField, 'booleanField), "accepts integral type")
-    assertError(BitwiseOr('booleanField, 'booleanField), "accepts integral type")
-    assertError(BitwiseXor('booleanField, 'booleanField), "accepts integral type")
+    assertError(BitwiseAnd('booleanField, 'booleanField), "requires integral type")
+    assertError(BitwiseOr('booleanField, 'booleanField), "requires integral type")
+    assertError(BitwiseXor('booleanField, 'booleanField), "requires integral type")
 
     assertError(MaxOf('complexField, 'complexField),
-      s"accepts ${TypeCollection.Ordered.simpleString} type")
+      s"requires ${TypeCollection.Ordered.simpleString} type")
     assertError(MinOf('complexField, 'complexField),
-      s"accepts ${TypeCollection.Ordered.simpleString} type")
+      s"requires ${TypeCollection.Ordered.simpleString} type")
   }
 
   test("check types for predicates") {
@@ -116,13 +116,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
 
     assertError(LessThan('complexField, 'complexField),
-      s"accepts ${TypeCollection.Ordered.simpleString} type")
+      s"requires ${TypeCollection.Ordered.simpleString} type")
     assertError(LessThanOrEqual('complexField, 'complexField),
-      s"accepts ${TypeCollection.Ordered.simpleString} type")
+      s"requires ${TypeCollection.Ordered.simpleString} type")
     assertError(GreaterThan('complexField, 'complexField),
-      s"accepts ${TypeCollection.Ordered.simpleString} type")
+      s"requires ${TypeCollection.Ordered.simpleString} type")
     assertError(GreaterThanOrEqual('complexField, 'complexField),
-      s"accepts ${TypeCollection.Ordered.simpleString} type")
+      s"requires ${TypeCollection.Ordered.simpleString} type")
 
     assertError(If('intField, 'stringField, 'stringField),
       "type of predicate expression in If should be boolean")
@@ -145,11 +145,11 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertSuccess(SumDistinct('stringField))
     assertSuccess(Average('stringField))
 
-    assertError(Min('complexField), "doesn't support ordering on function min")
-    assertError(Max('complexField), "doesn't support ordering on function max")
-    assertError(Sum('booleanField), "function sum accepts numeric type")
-    assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type")
-    assertError(Average('booleanField), "function average accepts numeric type")
+    assertError(Min('complexField), "min does not support ordering on type")
+    assertError(Max('complexField), "max does not support ordering on type")
+    assertError(Sum('booleanField), "function sum requires numeric type")
+    assertError(SumDistinct('booleanField), "function sumDistinct requires numeric type")
+    assertError(Average('booleanField), "function average requires numeric type")
   }
 
   test("check types for others") {
@@ -181,8 +181,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertSuccess(Round('intField, Literal(1)))
 
     assertError(Round('intField, 'intField), "Only foldable Expression is allowed")
-    assertError(Round('intField, 'booleanField), "expected to be of type int")
-    assertError(Round('intField, 'complexField), "expected to be of type int")
-    assertError(Round('booleanField, 'intField), "expected to be of type numeric")
+    assertError(Round('intField, 'booleanField), "requires int type")
+    assertError(Round('intField, 'complexField), "requires int type")
+    assertError(Round('booleanField, 'intField), "requires numeric type")
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4aff52d992..952ba7d45c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -89,18 +89,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
     }
 
-    private[this] def isValidSort(
-        leftKeys: Seq[Expression],
-        rightKeys: Seq[Expression]): Boolean = {
-      leftKeys.zip(rightKeys).forall { keys =>
-        (keys._1.dataType, keys._2.dataType) match {
-          case (l: AtomicType, r: AtomicType) => true
-          case (NullType, NullType) => true
-          case _ => false
-        }
-      }
-    }
-
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
         makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
@@ -111,7 +99,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // If the sort merge join option is set, we want to use sort merge join prior to hashjoin
       // for now let's support inner join first, then add outer join
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
-        if sqlContext.conf.sortMergeJoinEnabled && isValidSort(leftKeys, rightKeys) =>
+        if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
         val mergeJoin =
           joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
         condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 46921d1425..431dcf7382 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -305,13 +305,12 @@ class DataFrameFunctionsSuite extends QueryTest {
     val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b")
     assert(intercept[AnalysisException] {
       df2.selectExpr("sort_array(a)").collect()
-    }.getMessage().contains("Type ArrayType(IntegerType,false) is not the AtomicType, " +
-      "we can not perform the ordering operations"))
+    }.getMessage().contains("does not support sorting array of type array<int>"))
 
     val df3 = Seq(("xxx", "x")).toDF("a", "b")
     assert(intercept[AnalysisException] {
       df3.selectExpr("sort_array(a)").collect()
-    }.getMessage().contains("ArrayType(AtomicType) is expected, but we got StringType"))
+    }.getMessage().contains("only supports array input"))
   }
 
   test("array size function") {
-- 
GitLab