diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index b86825902ab3d8100d15dbf2b2d554f20af5122e..b87016d5a5696db8b1e4fe98fdca323679fd6c6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -65,6 +65,11 @@ case class InMemoryTableScanExec(
     case EqualTo(l: Literal, a: AttributeReference) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
 
+    case EqualNullSafe(a: AttributeReference, l: Literal) =>
+      statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
+    case EqualNullSafe(l: Literal, a: AttributeReference) =>
+      statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
+
     case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l
     case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index b99cd67a6344c34f226b9dbe2c7bd812f3664c37..9d862cfdecb2151c01c8e1af6f8f9a5399bccf20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -85,6 +85,8 @@ class PartitionBatchPruningSuite
   // Comparisons
   checkBatchPruning("SELECT key FROM pruningData WHERE key = 1", 1, 1)(Seq(1))
   checkBatchPruning("SELECT key FROM pruningData WHERE 1 = key", 1, 1)(Seq(1))
+  checkBatchPruning("SELECT key FROM pruningData WHERE key <=> 1", 1, 1)(Seq(1))
+  checkBatchPruning("SELECT key FROM pruningData WHERE 1 <=> key", 1, 1)(Seq(1))
   checkBatchPruning("SELECT key FROM pruningData WHERE key < 12", 1, 2)(1 to 11)
   checkBatchPruning("SELECT key FROM pruningData WHERE key <= 11", 1, 2)(1 to 11)
   checkBatchPruning("SELECT key FROM pruningData WHERE key > 88", 1, 2)(89 to 100)