diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 183e4947b6d7216b2aa6ee246f9371a1a4e8af1e..67a410f539b60d5349fda3dd1f49597f30e196d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -79,6 +79,11 @@ private[sql] case class InMemoryTableScanExec(
 
     case IsNull(a: Attribute) => statsFor(a).nullCount > 0
     case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
+
+    case In(a: AttributeReference, list: Seq[Expression])
+      if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty =>
+      list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
+        l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
   }
 
   val partitionFilters: Seq[Expression] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index 7ca8e047f081d9cd5ebc8b28bf0a91ed53dcaea2..b99cd67a6344c34f226b9dbe2c7bd812f3664c37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -65,11 +65,18 @@ class PartitionBatchPruningSuite
     }, 5).toDF()
     pruningData.createOrReplaceTempView("pruningData")
     spark.catalog.cacheTable("pruningData")
+
+    val pruningStringData = sparkContext.makeRDD((100 to 200).map { key =>
+      StringData(key.toString)
+    }, 5).toDF()
+    pruningStringData.createOrReplaceTempView("pruningStringData")
+    spark.catalog.cacheTable("pruningStringData")
   }
 
   override protected def afterEach(): Unit = {
     try {
       spark.catalog.uncacheTable("pruningData")
+      spark.catalog.uncacheTable("pruningStringData")
     } finally {
       super.afterEach()
     }
@@ -110,9 +117,23 @@ class PartitionBatchPruningSuite
     88 to 100
   }
 
-  // With unsupported predicate
+  // Support `IN` predicate
+  checkBatchPruning("SELECT key FROM pruningData WHERE key IN (1)", 1, 1)(Seq(1))
+  checkBatchPruning("SELECT key FROM pruningData WHERE key IN (1, 2)", 1, 1)(Seq(1, 2))
+  checkBatchPruning("SELECT key FROM pruningData WHERE key IN (1, 11)", 1, 2)(Seq(1, 11))
+  checkBatchPruning("SELECT key FROM pruningData WHERE key IN (1, 21, 41, 61, 81)", 5, 5)(
+    Seq(1, 21, 41, 61, 81))
+  checkBatchPruning("SELECT CAST(s AS INT) FROM pruningStringData WHERE s = '100'", 1, 1)(Seq(100))
+  checkBatchPruning("SELECT CAST(s AS INT) FROM pruningStringData WHERE s < '102'", 1, 1)(
+    Seq(100, 101))
+  checkBatchPruning(
+    "SELECT CAST(s AS INT) FROM pruningStringData WHERE s IN ('99', '150', '201')", 1, 1)(
+    Seq(150))
+
+  // With unsupported `InSet` predicate
   {
     val seq = (1 to 30).mkString(", ")
+    checkBatchPruning(s"SELECT key FROM pruningData WHERE key IN ($seq)", 5, 10)(1 to 30)
     checkBatchPruning(s"SELECT key FROM pruningData WHERE NOT (key IN ($seq))", 5, 10)(31 to 100)
     checkBatchPruning(s"SELECT key FROM pruningData WHERE NOT (key IN ($seq)) AND key > 88", 1, 2) {
       89 to 100
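
Reviewer note on the pruning rule: the new `In` case keeps a cached batch whenever at least one literal in the list can fall inside that batch's `[lowerBound, upperBound]` statistics, which is what the `reduce(_ || _)` over the per-literal range checks expresses. `InSet`, which the optimizer rewrites large `IN` lists into once they exceed `spark.sql.optimizer.inSetConversionThreshold` (10 by default), is still unhandled, which is why the 30-element list above still reads all 10 batches. The standalone sketch below illustrates why the test expectations hold; `BatchStats` and `mightContain` are hypothetical names invented for illustration, not Spark APIs, and the sketch assumes the suite's setup of keys 1 to 100 cached with a column batch size of 10 (configured elsewhere in `PartitionBatchPruningSuite`), giving per-batch bounds `[1,10], [11,20], ..., [91,100]`.

```scala
// Minimal model of IN-list batch pruning. BatchStats and mightContain are
// illustrative stand-ins, not Spark's actual ColumnStats API.
object InPruningSketch {
  // Per-batch min/max, playing the role of the lowerBound/upperBound stats.
  final case class BatchStats(min: Int, max: Int)

  // A batch must be read iff ANY literal can lie within [min, max];
  // this mirrors `list.map(<range check>).reduce(_ || _)` in the patch.
  def mightContain(stats: BatchStats, inList: Seq[Int]): Boolean =
    inList.exists(l => stats.min <= l && l <= stats.max)

  def main(args: Array[String]): Unit = {
    // Keys 1..100 in batches of 10: bounds [1,10], [11,20], ..., [91,100].
    val batches = (0 until 10).map(i => BatchStats(10 * i + 1, 10 * i + 10))

    // key IN (1, 2): both literals sit in [1,10] -> 1 batch read.
    println(batches.count(mightContain(_, Seq(1, 2))))              // 1
    // key IN (1, 11): hits [1,10] and [11,20] -> 2 batches read.
    println(batches.count(mightContain(_, Seq(1, 11))))             // 2
    // key IN (1, 21, 41, 61, 81): one hit per partition -> 5 batches.
    println(batches.count(mightContain(_, Seq(1, 21, 41, 61, 81)))) // 5
  }
}
```

These counts match the `checkBatchPruning` expectations added above (1, 2, and 5 read batches, respectively). The `list.nonEmpty` guard in the case pattern matters for the same reason it would in this model: `Seq.empty.reduce(_ || _)` throws, and an empty `In` is reachable through the `Column.isin()` API even though the SQL parser rejects `IN ()`.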