From 2a36292534a1e9f7a501e88f69bfc3a09fb62cb3 Mon Sep 17 00:00:00 2001
From: Lu Yan <luyan02@baidu.com>
Date: Mon, 9 Feb 2015 16:25:38 -0800
Subject: [PATCH] [SPARK-5614][SQL] Predicate pushdown through Generate.

Currently, Catalyst's rules cannot push predicates through `Generate` nodes. Furthermore, partition pruning in `HiveTableScan` cannot be applied to queries that involve `Generate`, which makes such queries very inefficient. In practice, the rule added here finds patterns like

```scala
Filter(predicate, Generate(generator, _, _, _, grandChild))
```

and splits the predicate into two parts, depending on whether or not each conjunct references a column generated by the `Generate` node. A new `Filter` is created for the conjuncts that can be pushed beneath the `Generate` node; if nothing is left for the original `Filter`, it is removed.
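As a minimal, self-contained sketch of that split (`Pred` and `split` are illustrative stand-ins, not the actual Catalyst API; the real rule in the diff below uses `PredicateHelper.splitConjunctivePredicates` and attribute sets):

```scala
// A conjunct can be pushed down iff it references only columns that are
// already available from the Generate node's child.
case class Pred(sql: String, references: Set[String])

def split(conjuncts: Seq[Pred], childOutput: Set[String]): (Seq[Pred], Seq[Pred]) =
  conjuncts.partition(p => p.references subsetOf childOutput)

// The child outputs [bk, len_arr, day]; `len` is produced by explode(len_arr).
val (pushDown, stayUp) = split(
  Seq(Pred("len > 5", Set("len")), Pred("day = '20150102'", Set("day"))),
  childOutput = Set("bk", "len_arr", "day"))
// pushDown == Seq(Pred("day = '20150102'", Set("day")))  -- goes beneath Generate
// stayUp   == Seq(Pred("len > 5", Set("len")))           -- stays in the Filter above
```
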
For example, the physical plan for the query
```sql
select len, bk
from s_server lateral view explode(len_arr) len_table as len
where len > 5 and day = '20150102';
```
where `day` is a partition column in the metastore, looks like this in the current version of Spark SQL:

> Project [len, bk]
>
> Filter ((len > "5") && (day = "20150102"))
>
> Generate explode(len_arr), true, false
>
> HiveTableScan [bk, len_arr, day], (MetastoreRelation default, s_server, None), None

But ideally the plan should look like this:

> Project [len, bk]
>
> Filter (len > "5")
>
> Generate explode(len_arr), true, false
>
> HiveTableScan [bk, len_arr, day], (MetastoreRelation default, s_server, None), Some(day = "20150102")

where the partition pruning predicate has been pushed down into the `HiveTableScan` node.
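
As a usage sketch (assuming a Spark 1.x `HiveContext` bound to an existing `SparkContext` `sc`, and the `s_server` table from the example above), the effect can be observed by printing the plans with `EXPLAIN EXTENDED`:

```scala
import org.apache.spark.sql.hive.HiveContext

val hc = new HiveContext(sc)  // `sc` is an existing SparkContext
// EXPLAIN EXTENDED prints the parsed, analyzed, optimized, and physical plans.
// With this rule in place, the optimized plan should show the partition
// predicate inside HiveTableScan instead of in a Filter above Generate.
hc.sql("""EXPLAIN EXTENDED
  SELECT len, bk
  FROM s_server LATERAL VIEW explode(len_arr) len_table AS len
  WHERE len > 5 AND day = '20150102'""").collect().foreach(println)
```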

Author: Lu Yan <luyan02@baidu.com>

Closes #4394 from ianluyan/ppd and squashes the following commits:

a67dce9 [Lu Yan] Fix English grammar.
7cea911 [Lu Yan] Revised based on @marmbrus's opinions
ffc59fc [Lu Yan] [SPARK-5614][SQL] Predicate pushdown through Generate.
---
 .../sql/catalyst/optimizer/Optimizer.scala    | 25 ++++++++
 .../optimizer/FilterPushdownSuite.scala       | 63 ++++++++++++++++++-
 2 files changed, 87 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 3bc48c95c5..fd58b9681e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -50,6 +50,7 @@ object DefaultOptimizer extends Optimizer {
       CombineFilters,
       PushPredicateThroughProject,
       PushPredicateThroughJoin,
+      PushPredicateThroughGenerate,
       ColumnPruning) ::
     Batch("LocalRelation", FixedPoint(100),
       ConvertToLocalRelation) :: Nil
@@ -455,6 +456,30 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference
+ * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath.
+ */
+object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case filter @ Filter(condition,
+    generate @ Generate(generator, join, outer, alias, grandChild)) =>
+      // Predicates that reference attributes produced by the `Generate` operator cannot
+      // be pushed below the operator.
+      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
+        conjunct => conjunct.references subsetOf grandChild.outputSet
+      }
+      if (pushDown.nonEmpty) {
+        val pushDownPredicate = pushDown.reduce(And)
+        val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild))
+        stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
+      } else {
+        filter
+      }
+  }
+}
+
 /**
  * Pushes down [[Filter]] operators where the `condition` can be
  * evaluated using only the attributes of the left or right side of a join.  Other
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index ebb123c1f9..1158b5dfc6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
+import org.apache.spark.sql.catalyst.expressions.Explode
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.types.IntegerType
 
 class FilterPushdownSuite extends PlanTest {
 
@@ -34,7 +36,8 @@ class FilterPushdownSuite extends PlanTest {
       Batch("Filter Pushdown", Once,
         CombineFilters,
         PushPredicateThroughProject,
-        PushPredicateThroughJoin) :: Nil
+        PushPredicateThroughJoin,
+        PushPredicateThroughGenerate) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -411,4 +414,62 @@ class FilterPushdownSuite extends PlanTest {
 
     comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer))
   }
+
+  val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
+
+  test("generate: predicate referenced no generated column") {
+    val originalQuery = {
+      testRelationWithArrayType
+        .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
+        .where(('b >= 5) && ('a > 6))
+    }
+    val optimized = Optimize(originalQuery.analyze)
+    val correctAnswer = {
+      testRelationWithArrayType
+        .where(('b >= 5) && ('a > 6))
+        .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze
+    }
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("generate: part of conjuncts referenced generated column") {
+    val generator = Explode(Seq("c"), 'c_arr)
+    val originalQuery = {
+      testRelationWithArrayType
+        .generate(generator, true, false, Some("arr"))
+        .where(('b >= 5) && ('c > 6))
+    }
+    val optimized = Optimize(originalQuery.analyze)
+    val referenceResult = {
+      testRelationWithArrayType
+        .where('b >= 5)
+        .generate(generator, true, false, Some("arr"))
+        .where('c > 6).analyze
+    }
+
+    // Newly generated columns get different expression ids each time the plan is
+    // analyzed, so e.g. comparePlans(originalQuery.analyze, originalQuery.analyze)
+    // fails. Therefore we check the operators manually here.
+    // Filter("c" > 6)
+    assertResult(classOf[Filter])(optimized.getClass)
+    assertResult(1)(optimized.asInstanceOf[Filter].condition.references.size)
+    assertResult("c"){
+      optimized.asInstanceOf[Filter].condition.references.toSeq(0).name
+    }
+
+    // The rest of the plan should match the reference result.
+    comparePlans(optimized.children(0), referenceResult.children(0))
+  }
+
+  test("generate: all conjuncts referenced generated column") {
+    val originalQuery = {
+      testRelationWithArrayType
+        .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
+        .where(('c > 6) || ('b > 5)).analyze
+    }
+    val optimized = Optimize(originalQuery)
+
+    comparePlans(optimized, originalQuery)
+  }
 }
-- 
GitLab