From 22be2ae147a111e88896f6fb42ed46bbf108a99b Mon Sep 17 00:00:00 2001
From: Cheng Lian <lian@databricks.com>
Date: Fri, 18 Sep 2015 18:42:20 -0700
Subject: [PATCH] [SPARK-10623] [SQL] Fixes ORC predicate push-down

When pushing down a leaf predicate, ORC `SearchArgument` builder requires an extra "parent" predicate (any one among `AND`/`OR`/`NOT`) to wrap the leaf predicate. E.g., to push down `a < 1`, we must build `AND(a < 1)` instead. Fortunately, when actually constructing the `SearchArgument`, the builder will eliminate all those unnecessary wrappers.

This PR is based on #8783 authored by zhzhan. I also took the chance to simply `OrcFilters` a little bit to improve readability.

Author: Cheng Lian <lian@databricks.com>

Closes #8799 from liancheng/spark-10623/fix-orc-ppd.
---
 .../spark/sql/hive/orc/OrcFilters.scala       | 56 ++++++++-----------
 .../spark/sql/hive/orc/OrcQuerySuite.scala    | 30 ++++++++++
 2 files changed, 52 insertions(+), 34 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala
index b3d9f7f71a..27193f54d3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala
@@ -31,11 +31,13 @@ import org.apache.spark.sql.sources._
  * and cannot be used anymore.
  */
 private[orc] object OrcFilters extends Logging {
-  def createFilter(expr: Array[Filter]): Option[SearchArgument] = {
-    expr.reduceOption(And).flatMap { conjunction =>
-      val builder = SearchArgumentFactory.newBuilder()
-      buildSearchArgument(conjunction, builder).map(_.build())
-    }
+  def createFilter(filters: Array[Filter]): Option[SearchArgument] = {
+    for {
+      // Combines all filters with `And`s to produce a single conjunction predicate
+      conjunction <- filters.reduceOption(And)
+      // Then tries to build a single ORC `SearchArgument` for the conjunction predicate
+      builder <- buildSearchArgument(conjunction, SearchArgumentFactory.newBuilder())
+    } yield builder.build()
   }
 
   private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = {
@@ -102,46 +104,32 @@ private[orc] object OrcFilters extends Logging {
           negate <- buildSearchArgument(child, builder.startNot())
         } yield negate.end()
 
-      case EqualTo(attribute, value) =>
-        Option(value)
-          .filter(isSearchableLiteral)
-          .map(builder.equals(attribute, _))
+      case EqualTo(attribute, value) if isSearchableLiteral(value) =>
+        Some(builder.startAnd().equals(attribute, value).end())
 
-      case EqualNullSafe(attribute, value) =>
-        Option(value)
-          .filter(isSearchableLiteral)
-          .map(builder.nullSafeEquals(attribute, _))
+      case EqualNullSafe(attribute, value) if isSearchableLiteral(value) =>
+        Some(builder.startAnd().nullSafeEquals(attribute, value).end())
 
-      case LessThan(attribute, value) =>
-        Option(value)
-          .filter(isSearchableLiteral)
-          .map(builder.lessThan(attribute, _))
+      case LessThan(attribute, value) if isSearchableLiteral(value) =>
+        Some(builder.startAnd().lessThan(attribute, value).end())
 
-      case LessThanOrEqual(attribute, value) =>
-        Option(value)
-          .filter(isSearchableLiteral)
-          .map(builder.lessThanEquals(attribute, _))
+      case LessThanOrEqual(attribute, value) if isSearchableLiteral(value) =>
+        Some(builder.startAnd().lessThanEquals(attribute, value).end())
 
-      case GreaterThan(attribute, value) =>
-        Option(value)
-          .filter(isSearchableLiteral)
-          .map(builder.startNot().lessThanEquals(attribute, _).end())
+      case GreaterThan(attribute, value) if isSearchableLiteral(value) =>
+        Some(builder.startNot().lessThanEquals(attribute, value).end())
 
-      case GreaterThanOrEqual(attribute, value) =>
-        Option(value)
-          .filter(isSearchableLiteral)
-          .map(builder.startNot().lessThan(attribute, _).end())
+      case GreaterThanOrEqual(attribute, value) if isSearchableLiteral(value) =>
+        Some(builder.startNot().lessThan(attribute, value).end())
 
       case IsNull(attribute) =>
-        Some(builder.isNull(attribute))
+        Some(builder.startAnd().isNull(attribute).end())
 
       case IsNotNull(attribute) =>
         Some(builder.startNot().isNull(attribute).end())
 
-      case In(attribute, values) =>
-        Option(values)
-          .filter(_.forall(isSearchableLiteral))
-          .map(builder.in(attribute, _))
+      case In(attribute, values) if values.forall(isSearchableLiteral) =>
+        Some(builder.startAnd().in(attribute, values.map(_.asInstanceOf[AnyRef]): _*).end())
 
       case _ => None
     }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index 8bc33fcf5d..5eb39b1129 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -344,4 +344,34 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
       }
     }
   }
+
+  test("SPARK-10623 Enable ORC PPD") {
+    withTempPath { dir =>
+      withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+        import testImplicits._
+
+        val path = dir.getCanonicalPath
+        sqlContext.range(10).coalesce(1).write.orc(path)
+        val df = sqlContext.read.orc(path)
+
+        def checkPredicate(pred: Column, answer: Seq[Long]): Unit = {
+          checkAnswer(df.where(pred), answer.map(Row(_)))
+        }
+
+        checkPredicate('id === 5, Seq(5L))
+        checkPredicate('id <=> 5, Seq(5L))
+        checkPredicate('id < 5, 0L to 4L)
+        checkPredicate('id <= 5, 0L to 5L)
+        checkPredicate('id > 5, 6L to 9L)
+        checkPredicate('id >= 5, 5L to 9L)
+        checkPredicate('id.isNull, Seq.empty[Long])
+        checkPredicate('id.isNotNull, 0L to 9L)
+        checkPredicate('id.isin(1L, 3L, 5L), Seq(1L, 3L, 5L))
+        checkPredicate('id > 0 && 'id < 3, 1L to 2L)
+        checkPredicate('id < 1 || 'id > 8, Seq(0L, 9L))
+        checkPredicate(!('id > 3), 0L to 3L)
+        checkPredicate(!('id > 0 && 'id < 3), Seq(0L) ++ (3L to 9L))
+      }
+    }
+  }
 }
-- 
GitLab