From d314677cfd9cb4140005765938841bae9dc48a2d Mon Sep 17 00:00:00 2001
From: hyukjinkwon <gurwls223@gmail.com>
Date: Thu, 1 Sep 2016 15:32:07 -0700
Subject: [PATCH] [SPARK-16461][SQL] Support partition batch pruning with `<=>`
 predicate in InMemoryTableScanExec

## What changes were proposed in this pull request?

It seems `EqualNullSafe` filter was missed for batch pruneing partitions in cached tables.

It seems supporting this improves the performance roughly 5 times faster.

Running the codes below:

```scala
test("Null-safe equal comparison") {
  val N = 20000000
  val df = spark.range(N).repartition(20)
  val benchmark = new Benchmark("Null-safe equal comparison", N)
  df.createOrReplaceTempView("t")
  spark.catalog.cacheTable("t")
  sql("select id from t where id <=> 1").collect()

  benchmark.addCase("Null-safe equal comparison", 10) { _ =>
    sql("select id from t where id <=> 1").collect()
  }
  benchmark.run()
}
```

produces the results below:

**Before:**

```
Running benchmark: Null-safe equal comparison
  Running case: Null-safe equal comparison
  Stopped after 10 iterations, 2098 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_45-b14 on Mac OS X 10.11.5
Intel(R) Core(TM) i7-4850HQ CPU  2.30GHz

Null-safe equal comparison:              Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Null-safe equal comparison                     204 /  210         98.1          10.2       1.0X
```

**After:**

```
Running benchmark: Null-safe equal comparison
  Running case: Null-safe equal comparison
  Stopped after 10 iterations, 478 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_45-b14 on Mac OS X 10.11.5
Intel(R) Core(TM) i7-4850HQ CPU  2.30GHz

Null-safe equal comparison:              Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Null-safe equal comparison                      42 /   48        474.1           2.1       1.0X
```

## How was this patch tested?

Unit tests in `PartitionBatchPruningSuite`.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #14117 from HyukjinKwon/SPARK-16461.
---
 .../spark/sql/execution/columnar/InMemoryTableScanExec.scala | 5 +++++
 .../sql/execution/columnar/PartitionBatchPruningSuite.scala  | 2 ++
 2 files changed, 7 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index b86825902a..b87016d5a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -65,6 +65,11 @@ case class InMemoryTableScanExec(
     case EqualTo(l: Literal, a: AttributeReference) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
 
+    case EqualNullSafe(a: AttributeReference, l: Literal) =>
+      statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
+    case EqualNullSafe(l: Literal, a: AttributeReference) =>
+      statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
+
     case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l
     case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index b99cd67a63..9d862cfdec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -85,6 +85,8 @@ class PartitionBatchPruningSuite
   // Comparisons
   checkBatchPruning("SELECT key FROM pruningData WHERE key = 1", 1, 1)(Seq(1))
   checkBatchPruning("SELECT key FROM pruningData WHERE 1 = key", 1, 1)(Seq(1))
+  checkBatchPruning("SELECT key FROM pruningData WHERE key <=> 1", 1, 1)(Seq(1))
+  checkBatchPruning("SELECT key FROM pruningData WHERE 1 <=> key", 1, 1)(Seq(1))
   checkBatchPruning("SELECT key FROM pruningData WHERE key < 12", 1, 2)(1 to 11)
   checkBatchPruning("SELECT key FROM pruningData WHERE key <= 11", 1, 2)(1 to 11)
   checkBatchPruning("SELECT key FROM pruningData WHERE key > 88", 1, 2)(89 to 100)
-- 
GitLab