diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2679e026bb00a00383350e4da688b111c9623938..805cad5cb953edc02fd0dedb941ffe642100ea86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -932,7 +932,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) joinType match { - case _: InnerLike | LeftExistence(_) => + case _: InnerLike | LeftSemi | ExistenceJoin(_) => // push down the single side only join filter for both sides sub queries val newLeft = leftJoinConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) @@ -949,14 +949,14 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And) Join(newLeft, newRight, RightOuter, newJoinCond) - case LeftOuter => + case LeftOuter | LeftAnti => // push down the right side only join filter for right sub query val newLeft = left val newRight = rightJoinConditions. reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) - Join(newLeft, newRight, LeftOuter, newJoinCond) + Join(newLeft, newRight, joinType, newJoinCond) case FullOuter => j case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") case UsingJoin(_, _) => sys.error("Untransformed Using join node") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 019f132d94cb259f30363f842f8413bdf077bba1..3e67282d687f5790c4af281f6b2d20fd3b30553e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -514,6 +514,39 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) } + test("joins: push down where clause into left anti join") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val originalQuery = + x.join(y, LeftAnti, Some("x.b".attr === "y.b".attr)) + .where("x.a".attr > 10) + .analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + x.where("x.a".attr > 10) + .join(y, LeftAnti, Some("x.b".attr === "y.b".attr)) + .analyze + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + } + + test("joins: only push down join conditions to the right of a left anti join") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val originalQuery = + x.join(y, + LeftAnti, + Some("x.b".attr === "y.b".attr && "y.a".attr > 10 && "x.a".attr > 10)).analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + x.join( + y.where("y.a".attr > 10), + LeftAnti, + Some("x.b".attr === "y.b".attr && "x.a".attr > 10)) + .analyze + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + } + + val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) test("generate: predicate referenced no generated column") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/anti-join.sql b/sql/core/src/test/resources/sql-tests/inputs/anti-join.sql new file mode 100644 index 0000000000000000000000000000000000000000..0346f57d609adf4345ec0a6315b83d933c54a4f1 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/anti-join.sql @@ -0,0 +1,7 @@ +-- SPARK-18597: Do not push down predicates to left hand side in an anti-join +CREATE OR REPLACE TEMPORARY VIEW tbl_a AS VALUES (1, 1), (2, 1), (3, 6) AS T(c1, c2); +CREATE OR REPLACE TEMPORARY VIEW tbl_b AS VALUES 1 AS T(c1); + +SELECT * +FROM tbl_a + LEFT ANTI JOIN tbl_b ON ((tbl_a.c1 = tbl_a.c2) IS NULL OR tbl_a.c1 = tbl_a.c2); diff --git a/sql/core/src/test/resources/sql-tests/results/anti-join.sql.out b/sql/core/src/test/resources/sql-tests/results/anti-join.sql.out new file mode 100644 index 0000000000000000000000000000000000000000..6f38c4d08bc5a9fcc3b3c9621fca216934b46e2b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/anti-join.sql.out @@ -0,0 +1,29 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 3 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW tbl_a AS VALUES (1, 1), (2, 1), (3, 6) AS T(c1, c2) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE OR REPLACE TEMPORARY VIEW tbl_b AS VALUES 1 AS T(c1) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * +FROM tbl_a + LEFT ANTI JOIN tbl_b ON ((tbl_a.c1 = tbl_a.c2) IS NULL OR tbl_a.c1 = tbl_a.c2) +-- !query 2 schema +struct<c1:int,c2:int> +-- !query 2 output +2 1 +3 6