diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 38538630c8b32809971bb61a0f42cb67550c9cad..437397187356ccb9c067865f6ccafa1bcd82ac28 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -594,6 +594,7 @@ class CodegenContext {
     case array: ArrayType => genComp(array, c1, c2) + " == 0"
     case struct: StructType => genComp(struct, c1, c2) + " == 0"
     case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
+    case NullType => "false"
     case _ =>
       throw new IllegalArgumentException(
         "cannot generate equality code for un-comparable type: " + dataType.simpleString)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 613d6202b0b26ede1c99d6e14cf55a028d5b6d5f..d3071c533bef974f40dc6c7935b46cee04200453 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -448,6 +448,16 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
 
 abstract class BinaryComparison extends BinaryOperator with Predicate {
 
+  // Note that we need to give a superset of allowable input types since orderable types are not
+  // finitely enumerable. The allowable types are checked below by checkInputDataTypes.
+  override def inputType: AbstractDataType = AnyDataType
+
+  override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
+    case TypeCheckResult.TypeCheckSuccess =>
+      TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName)
+    case failure => failure
+  }
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     if (ctx.isPrimitiveType(left.dataType)
         && left.dataType != BooleanType // java boolean doesn't support > or < operator
@@ -460,7 +470,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
     }
   }
 
-  protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
+  protected lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(left.dataType)
 }
 
 
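The unified check above delegates to `TypeUtils.checkForOrderingExpr`, which produces the error messages asserted in the test changes further down. A sketch of its likely shape, inferred from those messages (the actual implementation lives in TypeUtils.scala):

```scala
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types.DataType

// Inferred sketch: succeed for any orderable type, otherwise fail with the
// unified "<caller> does not support ordering on type <t>" message.
def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
  if (RowOrdering.isOrderable(dt)) {
    TypeCheckResult.TypeCheckSuccess
  } else {
    TypeCheckResult.TypeCheckFailure(
      s"$caller does not support ordering on type ${dt.simpleString}")
  }
}
```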
@@ -478,28 +488,13 @@ object Equality {
   }
 }
 
+// TODO: although map type is not orderable, it should technically be usable
+// in equality comparison.
 @ExpressionDescription(
   usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` equals `expr2`, or false otherwise.")
 case class EqualTo(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = AnyDataType
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    super.checkInputDataTypes() match {
-      case TypeCheckResult.TypeCheckSuccess =>
-        // TODO: although map type is not orderable, technically map type should be able to be used
-        // in equality comparison, remove this type check once we support it.
-        if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
-          TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " +
-            s"input type is ${left.dataType.catalogString}.")
-        } else {
-          TypeCheckResult.TypeCheckSuccess
-        }
-      case failure => failure
-    }
-  }
-
   override def symbol: String = "="
 
   protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)
@@ -509,6 +504,8 @@ case class EqualTo(left: Expression, right: Expression)
   }
 }
 
+// TODO: although map type is not orderable, it should technically be usable
+// in equality comparison.
 @ExpressionDescription(
   usage = """
     expr1 _FUNC_ expr2 - Returns same result as the EQUAL(=) operator for non-null operands,
@@ -516,23 +513,6 @@ case class EqualTo(left: Expression, right: Expression)
   """)
 case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
 
-  override def inputType: AbstractDataType = AnyDataType
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    super.checkInputDataTypes() match {
-      case TypeCheckResult.TypeCheckSuccess =>
-        // TODO: although map type is not orderable, technically map type should be able to be used
-        // in equality comparison, remove this type check once we support it.
-        if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
-          TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " +
-            s"input type is ${left.dataType.catalogString}.")
-        } else {
-          TypeCheckResult.TypeCheckSuccess
-        }
-      case failure => failure
-    }
-  }
-
   override def symbol: String = "<=>"
 
   override def nullable: Boolean = false
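The NULL semantics of the two equality operators are easiest to see side by side; this mirrors the assertions added to PredicateSuite below (`checkEvaluation` comes from ExpressionEvalHelper):

```scala
val n = Literal.create(null, NullType)
checkEvaluation(EqualTo(n, n), null)       // NULL =   NULL -> NULL (null-intolerant)
checkEvaluation(EqualNullSafe(n, n), true) // NULL <=> NULL -> true (never null)
```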
@@ -564,8 +544,6 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
 case class LessThan(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = TypeCollection.Ordered
-
   override def symbol: String = "<"
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
@@ -576,8 +554,6 @@ case class LessThan(left: Expression, right: Expression)
 case class LessThanOrEqual(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = TypeCollection.Ordered
-
   override def symbol: String = "<="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
@@ -588,8 +564,6 @@ case class LessThanOrEqual(left: Expression, right: Expression)
 case class GreaterThan(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = TypeCollection.Ordered
-
   override def symbol: String = ">"
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
@@ -600,8 +574,6 @@ case class GreaterThan(left: Expression, right: Expression)
 case class GreaterThanOrEqual(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = TypeCollection.Ordered
-
   override def symbol: String = ">="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
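With the per-operator `TypeCollection.Ordered` overrides gone, all four comparisons inherit `AnyDataType` plus the ordering check from `BinaryComparison`, so orderable complex types work uniformly. An illustrative check using array literals as in the updated PredicateSuite:

```scala
// Arrays are orderable, so this now passes type checking and evaluates
// via the interpreted array ordering (illustrative).
val cmp = GreaterThan(Literal(Array(2L, 1L)), Literal(Array(1L, 2L)))
assert(cmp.checkInputDataTypes() == TypeCheckResult.TypeCheckSuccess)
```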
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 45225779bffcb8b9a99dc8e5242da47bc30e3fcd..1dcda49a3af6ae25fdc303f12e48ab3c00faad12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -65,6 +65,7 @@ object TypeUtils {
       case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
       case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
       case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
+      case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType)
     }
   }
 
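The new case makes a UDT's interpreted ordering the ordering of its underlying SQL type, which is what the `ExamplePointUDT` comparisons added to PredicateSuite rely on. Illustrative usage:

```scala
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
import org.apache.spark.sql.catalyst.util.TypeUtils

// ExamplePointUDT.sqlType is an array of doubles, so this resolves, via
// the recursion above, to the interpreted array ordering.
val pointOrdering = TypeUtils.getInterpretedOrdering(new ExamplePointUDT)
```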
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 1d54ff5825c2e4586cc56373fb1ceb91ef350817..3041f44b116ead5190304887e10b6b20aa8ae1a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -78,18 +78,6 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
 
 private[sql] object TypeCollection {
 
-  /**
-   * Types that can be ordered/compared. In the long run we should probably make this a trait
-   * that can be mixed into each data type, and perhaps create an `AbstractDataType`.
-   */
-  // TODO: Should we consolidate this with RowOrdering.isOrderable?
-  val Ordered = TypeCollection(
-    BooleanType,
-    ByteType, ShortType, IntegerType, LongType,
-    FloatType, DoubleType, DecimalType,
-    TimestampType, DateType,
-    StringType, BinaryType)
-
   /**
    * Types that include numeric types and interval type. They are only used in unary_minus,
    * unary_positive, add and subtract operations.
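Deleting `TypeCollection.Ordered` effectively resolves the TODO above in favour of `RowOrdering.isOrderable`, which is recursive rather than a finite list. A paraphrased sketch of that predicate (it lives in catalyst, where `AtomicType` is visible):

```scala
import org.apache.spark.sql.types._

// Paraphrased: orderability is defined recursively over element and field
// types, which a flat TypeCollection could not express; maps stay unorderable.
def isOrderable(dataType: DataType): Boolean = dataType match {
  case NullType => true
  case _: AtomicType => true
  case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
  case array: ArrayType => isOrderable(array.elementType)
  case udt: UserDefinedType[_] => isOrderable(udt.sqlType)
  case _ => false
}
```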
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 4e0613619add617370b06f8a347ed1535adbf279..884e113537c934b072abfdde946389b97a0f2787 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -505,7 +505,7 @@ class AnalysisErrorSuite extends AnalysisTest {
       right,
       joinType = Cross,
       condition = Some('b === 'd))
-    assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil)
+    assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil)
   }
 
   test("PredicateSubQuery is used outside of a filter") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 30725773a37b1d4bd49493868403ae585bbdebc6..36714bd631b0ee7ed766f44fd29b1ec027f8cb1e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{LongType, StringType, TypeCollection}
+import org.apache.spark.sql.types._
 
 class ExpressionTypeCheckingSuite extends SparkFunSuite {
 
@@ -109,16 +109,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
     assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
 
-    assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo")
-    assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe")
+    assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type MapType")
+    assertError(EqualNullSafe('mapField, 'mapField),
+      "EqualNullSafe does not support ordering on type MapType")
     assertError(LessThan('mapField, 'mapField),
-      s"requires ${TypeCollection.Ordered.simpleString} type")
+      "LessThan does not support ordering on type MapType")
     assertError(LessThanOrEqual('mapField, 'mapField),
-      s"requires ${TypeCollection.Ordered.simpleString} type")
+      "LessThanOrEqual does not support ordering on type MapType")
     assertError(GreaterThan('mapField, 'mapField),
-      s"requires ${TypeCollection.Ordered.simpleString} type")
+      "GreaterThan does not support ordering on type MapType")
     assertError(GreaterThanOrEqual('mapField, 'mapField),
-      s"requires ${TypeCollection.Ordered.simpleString} type")
+      "GreaterThanOrEqual does not support ordering on type MapType")
 
     assertError(If('intField, 'stringField, 'stringField),
       "type of predicate expression in If should be boolean")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index ef510a95ef446e3b35f250f803ab587d7c60561d..055c31c2b30185a3de38666b6b377ee9403ac84b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.{Date, Timestamp}
+
 import scala.collection.immutable.HashSet
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.RandomDataGenerator
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
 
 
@@ -215,14 +218,35 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
-  private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
+  private case class MyStruct(a: Long, b: String)
+  private case class MyStruct2(a: MyStruct, b: Array[Int])
+  private val udt = new ExamplePointUDT
+
+  private val smallValues =
+    Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
+      new Timestamp(1), "a", 1f, 1d, 0f, 0d, false, Array(1L, 2L))
+      .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
+      Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
   private val largeValues =
-    Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, true).map(Literal(_))
+    Seq(2.toByte, 2.toShort, 2, 2L, Decimal(2), Array(2.toByte), new Date(2000, 1, 2),
+      new Timestamp(2), "b", 2f, 2d, Float.NaN, Double.NaN, true, Array(2L, 1L))
+      .map(Literal(_)) ++ Seq(Literal.create(MyStruct(2L, "b")),
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 2))),
+      Literal.create(ArrayData.toArrayData(Array(1.0, 3.0)), udt))
 
   private val equalValues1 =
-    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
+    Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
+      new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
+      .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
+      Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
   private val equalValues2 =
-    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
+    Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
+      new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
+      .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
+      Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
 
   test("BinaryComparison consistency check") {
     DataTypeTestUtils.ordered.foreach { dt =>
@@ -285,11 +309,13 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     // Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757
     val normalInt = Literal(-1)
     val nullInt = NonFoldableLiteral.create(null, IntegerType)
+    val nullNullType = Literal.create(null, NullType)
 
     def nullTest(op: (Expression, Expression) => Expression): Unit = {
       checkEvaluation(op(normalInt, nullInt), null)
       checkEvaluation(op(nullInt, normalInt), null)
       checkEvaluation(op(nullInt, nullInt), null)
+      checkEvaluation(op(nullNullType, nullNullType), null)
     }
 
     nullTest(LessThan)
@@ -301,6 +327,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(EqualNullSafe(normalInt, nullInt), false)
     checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
     checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
+    checkEvaluation(EqualNullSafe(nullNullType, nullNullType), true)
   }
 
   test("EqualTo on complex type") {