diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 98e50d0d3c6741d04619ad64ab085977f4d71b6a..80e577e5c4c798f88579b9fc31dfc04e058cce4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -183,21 +183,6 @@ trait CheckAnalysis extends PredicateHelper {
               s"join condition '${condition.sql}' " +
                 s"of type ${condition.dataType.simpleString} is not a boolean.")
 
-          case j @ Join(_, _, _, Some(condition)) =>
-            def checkValidJoinConditionExprs(expr: Expression): Unit = expr match {
-              case p: Predicate =>
-                p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs)
-              case e if e.dataType.isInstanceOf[BinaryType] =>
-                failAnalysis(s"binary type expression ${e.sql} cannot be used " +
-                  "in join conditions")
-              case e if e.dataType.isInstanceOf[MapType] =>
-                failAnalysis(s"map type expression ${e.sql} cannot be used " +
-                  "in join conditions")
-              case _ => // OK
-            }
-
-            checkValidJoinConditionExprs(condition)
-
           case Aggregate(groupingExprs, aggregateExprs, child) =>
             def checkValidAggregateExpression(expr: Expression): Unit = expr match {
               case aggExpr: AggregateExpression =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 7946c201f4ffc20bb4e46ee5d83091e58189a612..2ad452b6a90cac2d920ccd6664ed9576f1863b35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -412,6 +412,21 @@ case class EqualTo(left: Expression, right: Expression)
 
   override def inputType: AbstractDataType = AnyDataType
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    super.checkInputDataTypes() match {
+      case TypeCheckResult.TypeCheckSuccess =>
+        // TODO: although map type is not orderable, technically map type should be able to be used
+        // in equality comparison, remove this type check once we support it.
+        if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
+          TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " +
+            s"input type is ${left.dataType.catalogString}.")
+        } else {
+          TypeCheckResult.TypeCheckSuccess
+        }
+      case failure => failure
+    }
+  }
+
   override def symbol: String = "="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@@ -440,6 +455,21 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
 
   override def inputType: AbstractDataType = AnyDataType
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    super.checkInputDataTypes() match {
+      case TypeCheckResult.TypeCheckSuccess =>
+        // TODO: although map type is not orderable, technically map type should be able to be used
+        // in equality comparison, remove this type check once we support it.
+        if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
+          TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " +
+            s"input type is ${left.dataType.catalogString}.")
+        } else {
+          TypeCheckResult.TypeCheckSuccess
+        }
+      case failure => failure
+    }
+  }
+
   override def symbol: String = "<=>"
 
   override def nullable: Boolean = false
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 21afe9fec5944a353e930bf1b7cc2664a4d24351..8c1faea2394c6e2e7be1e05c663ba36062e251b3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -465,34 +465,22 @@ class AnalysisErrorSuite extends AnalysisTest {
         "another aggregate function." :: Nil)
   }
 
-  test("Join can't work on binary and map types") {
-    val plan =
-      Join(
-        LocalRelation(
-          AttributeReference("a", BinaryType)(exprId = ExprId(2)),
-          AttributeReference("b", IntegerType)(exprId = ExprId(1))),
-        LocalRelation(
-          AttributeReference("c", BinaryType)(exprId = ExprId(4)),
-          AttributeReference("d", IntegerType)(exprId = ExprId(3))),
-        Cross,
-        Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)),
-          AttributeReference("c", BinaryType)(exprId = ExprId(4)))))
-
-    assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil)
-
-    val plan2 =
-      Join(
-        LocalRelation(
-          AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
-          AttributeReference("b", IntegerType)(exprId = ExprId(1))),
-        LocalRelation(
-          AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)),
-          AttributeReference("d", IntegerType)(exprId = ExprId(3))),
-        Cross,
-        Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
-          AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)))))
-
-    assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil)
+  test("Join can work on binary types but can't work on map types") {
+    val left = LocalRelation('a.binary, 'b.map(StringType, StringType))
+    val right = LocalRelation('c.binary, 'd.map(StringType, StringType))
+
+    val plan1 = left.join(
+      right,
+      joinType = Cross,
+      condition = Some('a === 'c))
+
+    assertAnalysisSuccess(plan1)
+
+    val plan2 = left.join(
+      right,
+      joinType = Cross,
+      condition = Some('b === 'd))
+    assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil)
   }
 
   test("PredicateSubQuery is used outside of a filter") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 542e654bbce12b168cd31109c4176f26277d9b1f..744057b7c5f4cfa1e4b751ebd86ca655343a7973 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -111,6 +111,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
     assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
 
+    assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo")
+    assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe")
     assertError(LessThan('mapField, 'mapField),
       s"requires ${TypeCollection.Ordered.simpleString} type")
     assertError(LessThanOrEqual('mapField, 'mapField),