diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 183e4947b6d7216b2aa6ee246f9371a1a4e8af1e..67a410f539b60d5349fda3dd1f49597f30e196d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -79,6 +79,12 @@ private[sql] case class InMemoryTableScanExec(
 
     case IsNull(a: Attribute) => statsFor(a).nullCount > 0
     case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
+
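+    // Keep a batch if any IN-list literal falls within [lowerBound, upperBound] for `a`.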
+    case In(a: AttributeReference, list: Seq[Expression])
+      if list.nonEmpty && list.forall(_.isInstanceOf[Literal]) =>
+      list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
+        l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
   }
 
   val partitionFilters: Seq[Expression] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index 7ca8e047f081d9cd5ebc8b28bf0a91ed53dcaea2..b99cd67a6344c34f226b9dbe2c7bd812f3664c37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -65,11 +65,19 @@ class PartitionBatchPruningSuite
     }, 5).toDF()
     pruningData.createOrReplaceTempView("pruningData")
     spark.catalog.cacheTable("pruningData")
+
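+    // String keys "100" to "200", cached in 5 partitions for string-statistics pruning tests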
+    val pruningStringData = sparkContext.makeRDD((100 to 200).map { key =>
+      StringData(key.toString)
+    }, 5).toDF()
+    pruningStringData.createOrReplaceTempView("pruningStringData")
+    spark.catalog.cacheTable("pruningStringData")
   }
 
   override protected def afterEach(): Unit = {
     try {
       spark.catalog.uncacheTable("pruningData")
+      spark.catalog.uncacheTable("pruningStringData")
     } finally {
       super.afterEach()
     }
@@ -110,9 +118,25 @@ class PartitionBatchPruningSuite
     88 to 100
   }
 
-  // With unsupported predicate
+  // Support `IN` predicate
+  checkBatchPruning("SELECT key FROM pruningData WHERE key IN (1)", 1, 1)(Seq(1))
+  checkBatchPruning("SELECT key FROM pruningData WHERE key IN (1, 2)", 1, 1)(Seq(1, 2))
+  checkBatchPruning("SELECT key FROM pruningData WHERE key IN (1, 11)", 1, 2)(Seq(1, 11))
+  checkBatchPruning("SELECT key FROM pruningData WHERE key IN (1, 21, 41, 61, 81)", 5, 5)(
+    Seq(1, 21, 41, 61, 81))
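+
+  // Pruning based on string column statistics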
+  checkBatchPruning("SELECT CAST(s AS INT) FROM pruningStringData WHERE s = '100'", 1, 1)(Seq(100))
+  checkBatchPruning("SELECT CAST(s AS INT) FROM pruningStringData WHERE s < '102'", 1, 1)(
+    Seq(100, 101))
+  checkBatchPruning(
+    "SELECT CAST(s AS INT) FROM pruningStringData WHERE s IN ('99', '150', '201')", 1, 1)(
+      Seq(150))
+
+  // With unsupported `InSet` predicate (long `IN` lists are converted to `InSet` by the optimizer)
   {
     val seq = (1 to 30).mkString(", ")
+    checkBatchPruning(s"SELECT key FROM pruningData WHERE key IN ($seq)", 5, 10)(1 to 30)
     checkBatchPruning(s"SELECT key FROM pruningData WHERE NOT (key IN ($seq))", 5, 10)(31 to 100)
     checkBatchPruning(s"SELECT key FROM pruningData WHERE NOT (key IN ($seq)) AND key > 88", 1, 2) {
       89 to 100