From db0c038a66cb228bcb62a5607cd0ed013d0f9f20 Mon Sep 17 00:00:00 2001
From: Cheng Hao <hao.cheng@intel.com>
Date: Tue, 10 Jun 2014 12:59:52 -0700
Subject: [PATCH] [SPARK-2076][SQL] Pushdown the join filter & predication for
 outer join

As the rule described in https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior, we can optimize the SQL Join by pushing down the Join predicate and Where predicate.

Author: Cheng Hao <hao.cheng@intel.com>

Closes #1015 from chenghao-intel/join_predicate_push_down and squashes the following commits:

10feff9 [Cheng Hao] fix bug of changing the join type in PredicatePushDownThroughJoin
44c6700 [Cheng Hao] Add logical to support pushdown the join filter
0bce426 [Cheng Hao] Pushdown the join filter & predicate for outer join
---
 .../sql/catalyst/optimizer/Optimizer.scala    | 112 +++++++++--
 .../optimizer/FilterPushdownSuite.scala       | 187 +++++++++++++++++-
 2 files changed, 277 insertions(+), 22 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 406ffd6801..ccb8245cc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -19,6 +19,10 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.FullOuter
+import org.apache.spark.sql.catalyst.plans.LeftOuter
+import org.apache.spark.sql.catalyst.plans.RightOuter
+import org.apache.spark.sql.catalyst.plans.LeftSemi
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.types._
@@ -34,7 +38,7 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
     Batch("Filter Pushdown", FixedPoint(100),
       CombineFilters,
       PushPredicateThroughProject,
-      PushPredicateThroughInnerJoin,
+      PushPredicateThroughJoin,
       ColumnPruning) :: Nil
 }
 
@@ -254,28 +258,98 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
 
 /**
  * Pushes down [[catalyst.plans.logical.Filter Filter]] operators where the `condition` can be
- * evaluated using only the attributes of the left or right side of an inner join.  Other
+ * evaluated using only the attributes of the left or right side of a join.  Other
  * [[catalyst.plans.logical.Filter Filter]] conditions are moved into the `condition` of the
  * [[catalyst.plans.logical.Join Join]].
+ * And also Pushes down the join filter, where the `condition` can be evaluated using only the 
+ * attributes of the left or right side of sub query when applicable. 
+ * 
+ * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details
  */
-object PushPredicateThroughInnerJoin extends Rule[LogicalPlan] with PredicateHelper {
+object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
+  // split the condition expression into 3 parts, 
+  // (canEvaluateInLeftSide, canEvaluateInRightSide, haveToEvaluateWithBothSide) 
+  private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
+    val (leftEvaluateCondition, rest) =
+        condition.partition(_.references subsetOf left.outputSet)
+    val (rightEvaluateCondition, commonCondition) = 
+        rest.partition(_.references subsetOf right.outputSet)
+
+    (leftEvaluateCondition, rightEvaluateCondition, commonCondition)
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case f @ Filter(filterCondition, Join(left, right, Inner, joinCondition)) =>
-      val allConditions =
-        splitConjunctivePredicates(filterCondition) ++
-          joinCondition.map(splitConjunctivePredicates).getOrElse(Nil)
-
-      // Split the predicates into those that can be evaluated on the left, right, and those that
-      // must be evaluated after the join.
-      val (rightConditions, leftOrJoinConditions) =
-        allConditions.partition(_.references subsetOf right.outputSet)
-      val (leftConditions, joinConditions) =
-        leftOrJoinConditions.partition(_.references subsetOf left.outputSet)
-
-      // Build the new left and right side, optionally with the pushed down filters.
-      val newLeft = leftConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
-      val newRight = rightConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
-      Join(newLeft, newRight, Inner, joinConditions.reduceLeftOption(And))
+    // push the where condition down into join filter
+    case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) =>
+      val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = 
+        split(splitConjunctivePredicates(filterCondition), left, right)
+
+      joinType match {
+        case Inner =>
+          // push down the single side `where` condition into respective sides
+          val newLeft = leftFilterConditions.
+            reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+          val newRight = rightFilterConditions.
+            reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+          val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And)
+
+          Join(newLeft, newRight, Inner, newJoinCond)
+        case RightOuter =>
+          // push down the right side only `where` condition
+          val newLeft = left
+          val newRight = rightFilterConditions.
+            reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+          val newJoinCond = joinCondition
+          val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond)
+
+          (leftFilterConditions ++ commonFilterCondition).
+            reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
+        case _ @ (LeftOuter | LeftSemi) =>
+          // push down the left side only `where` condition
+          val newLeft = leftFilterConditions.
+            reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+          val newRight = right
+          val newJoinCond = joinCondition
+          val newJoin = Join(newLeft, newRight, joinType, newJoinCond)
+
+          (rightFilterConditions ++ commonFilterCondition).
+            reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
+        case FullOuter => f // DO Nothing for Full Outer Join
+      }
+
+    // push down the join filter into sub query scanning if applicable
+    case f @ Join(left, right, joinType, joinCondition) =>
+      val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = 
+        split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
+
+      joinType match {
+        case Inner =>
+          // push down the single side only join filter for both sides sub queries
+          val newLeft = leftJoinConditions.
+            reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+          val newRight = rightJoinConditions.
+            reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+          val newJoinCond = commonJoinCondition.reduceLeftOption(And)
+
+          Join(newLeft, newRight, Inner, newJoinCond)
+        case RightOuter =>
+          // push down the left side only join filter for left side sub query
+          val newLeft = leftJoinConditions.
+            reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+          val newRight = right
+          val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
+
+          Join(newLeft, newRight, RightOuter, newJoinCond)
+        case _ @ (LeftOuter | LeftSemi) =>
+          // push down the right side only join filter for right sub query
+          val newLeft = left
+          val newRight = rightJoinConditions.
+            reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+          val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
+
+          Join(newLeft, newRight, joinType, newJoinCond)
+        case FullOuter => f
+      }
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index ef47850455..02cc665f8a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -20,11 +20,14 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.FullOuter
+import org.apache.spark.sql.catalyst.plans.LeftOuter
+import org.apache.spark.sql.catalyst.plans.RightOuter
 import org.apache.spark.sql.catalyst.rules._
-
-/* Implicit conversions */
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.junit.Test
 
 class FilterPushdownSuite extends OptimizerTest {
 
@@ -35,7 +38,7 @@ class FilterPushdownSuite extends OptimizerTest {
       Batch("Filter Pushdown", Once,
         CombineFilters,
         PushPredicateThroughProject,
-        PushPredicateThroughInnerJoin) :: Nil
+        PushPredicateThroughJoin) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -161,6 +164,184 @@ class FilterPushdownSuite extends OptimizerTest {
 
     comparePlans(optimized, correctAnswer)
   }
+  
+  test("joins: push down left outer join #1") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, LeftOuter)
+        .where("x.b".attr === 1 && "y.b".attr === 2)
+    }
+
+    val optimized = Optimize(originalQuery.analyze)
+    val left = testRelation.where('b === 1)
+    val correctAnswer =
+      left.join(y, LeftOuter).where("y.b".attr === 2).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("joins: push down right outer join #1") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, RightOuter)
+        .where("x.b".attr === 1 && "y.b".attr === 2)
+    }
+
+    val optimized = Optimize(originalQuery.analyze)
+    val right = testRelation.where('b === 2).subquery('d)
+    val correctAnswer =
+      x.join(right, RightOuter).where("x.b".attr === 1).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("joins: push down left outer join #2") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, LeftOuter, Some("x.b".attr === 1))
+        .where("x.b".attr === 2 && "y.b".attr === 2)
+    }
+
+    val optimized = Optimize(originalQuery.analyze)
+    val left = testRelation.where('b === 2).subquery('d)
+    val correctAnswer =
+      left.join(y, LeftOuter, Some("d.b".attr === 1)).where("y.b".attr === 2).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("joins: push down right outer join #2") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, RightOuter, Some("y.b".attr === 1))
+        .where("x.b".attr === 2 && "y.b".attr === 2)
+    }
+
+    val optimized = Optimize(originalQuery.analyze)
+    val right = testRelation.where('b === 2).subquery('d)
+    val correctAnswer =
+      x.join(right, RightOuter, Some("d.b".attr === 1)).where("x.b".attr === 2).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("joins: push down left outer join #3") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, LeftOuter, Some("y.b".attr === 1))
+        .where("x.b".attr === 2 && "y.b".attr === 2)
+    }
+
+    val optimized = Optimize(originalQuery.analyze)
+    val left = testRelation.where('b === 2).subquery('l)
+    val right = testRelation.where('b === 1).subquery('r)
+    val correctAnswer =
+      left.join(right, LeftOuter).where("r.b".attr === 2).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("joins: push down right outer join #3") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, RightOuter, Some("y.b".attr === 1))
+        .where("x.b".attr === 2 && "y.b".attr === 2)
+    }
+
+    val optimized = Optimize(originalQuery.analyze)
+    val right = testRelation.where('b === 2).subquery('r)
+    val correctAnswer =
+      x.join(right, RightOuter, Some("r.b".attr === 1)).where("x.b".attr === 2).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("joins: push down left outer join #4") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, LeftOuter, Some("y.b".attr === 1))
+        .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
+    }
+
+    val optimized = Optimize(originalQuery.analyze)
+    val left = testRelation.where('b === 2).subquery('l)
+    val right = testRelation.where('b === 1).subquery('r)
+    val correctAnswer =
+      left.join(right, LeftOuter).where("r.b".attr === 2 && "l.c".attr === "r.c".attr).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("joins: push down right outer join #4") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, RightOuter, Some("y.b".attr === 1))
+        .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
+    }
+
+    val optimized = Optimize(originalQuery.analyze)
+    val left = testRelation.subquery('l)
+    val right = testRelation.where('b === 2).subquery('r)
+    val correctAnswer =
+      left.join(right, RightOuter, Some("r.b".attr === 1)).
+        where("l.b".attr === 2 && "l.c".attr === "r.c".attr).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("joins: push down left outer join #5") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, LeftOuter, Some("y.b".attr === 1 && "x.a".attr === 3))
+        .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
+    }
+
+    val optimized = Optimize(originalQuery.analyze)
+    val left = testRelation.where('b === 2).subquery('l)
+    val right = testRelation.where('b === 1).subquery('r)
+    val correctAnswer =
+      left.join(right, LeftOuter, Some("l.a".attr===3)).
+        where("r.b".attr === 2 && "l.c".attr === "r.c".attr).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("joins: push down right outer join #5") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, RightOuter, Some("y.b".attr === 1 && "x.a".attr === 3))
+        .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
+    }
+
+    val optimized = Optimize(originalQuery.analyze)
+    val left = testRelation.where('a === 3).subquery('l)
+    val right = testRelation.where('b === 2).subquery('r)
+    val correctAnswer =
+      left.join(right, RightOuter, Some("r.b".attr === 1)).
+        where("l.b".attr === 2 && "l.c".attr === "r.c".attr).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 
   test("joins: can't push down") {
     val x = testRelation.subquery('x)
-- 
GitLab