From f2e22aebfe49cdfdf20f060305772971bcea9266 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <viirya@gmail.com>
Date: Wed, 6 Sep 2017 07:42:19 -0700
Subject: [PATCH] [SPARK-21835][SQL] RewritePredicateSubquery should not
 produce unresolved query plans

## What changes were proposed in this pull request?

Correlated predicate subqueries are rewritten into `Join` by the rule `RewritePredicateSubquery` during optimization.

It is possible that the two sides of the `Join` have conflicting attributes. The query plans produced by `RewritePredicateSubquery` then become unresolved and break structural integrity.

We should check if there are conflicting attributes in the `Join` and de-duplicate them by adding a `Project`. (Two illustrative sketches, a minimal reproduction and a NULL-aware `NOT IN` demo, follow the patch below.)

## How was this patch tested?

Added tests.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #19050 from viirya/SPARK-21835.
---
 .../sql/catalyst/optimizer/subquery.scala     | 39 ++++++++++--
 .../org/apache/spark/sql/SubquerySuite.scala  | 63 +++++++++++++++++++
 2 files changed, 98 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 4386a10162..7ff891516d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -49,6 +49,33 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
+  private def dedupJoin(joinPlan: Join): Join = joinPlan match {
+    // SPARK-21835: It is possible that the two sides of the join have conflicting attributes;
+    // the produced join then becomes unresolved and breaks structural integrity. We should
+    // de-duplicate conflicting attributes. We don't use transformation here because we only
+    // care about the topmost join converted from a correlated predicate subquery.
+    case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti), joinCond) =>
+      val duplicates = right.outputSet.intersect(left.outputSet)
+      if (duplicates.nonEmpty) {
+        val aliasMap = AttributeMap(duplicates.map { dup =>
+          dup -> Alias(dup, dup.toString)()
+        }.toSeq)
+        val aliasedExpressions = right.output.map { ref =>
+          aliasMap.getOrElse(ref, ref)
+        }
+        val newRight = Project(aliasedExpressions, right)
+        val newJoinCond = joinCond.map { condExpr =>
+          condExpr transform {
+            case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
+          }
+        }
+        Join(left, newRight, joinType, newJoinCond)
+      } else {
+        j
+      }
+    case _ => joinPlan
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case Filter(condition, child) =>
       val (withSubquery, withoutSubquery) =
@@ -64,14 +91,17 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       withSubquery.foldLeft(newFilter) {
         case (p, Exists(sub, conditions, _)) =>
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
-          Join(outerPlan, sub, LeftSemi, joinCond)
+          // Deduplicate conflicting attributes if any.
+          dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
         case (p, Not(Exists(sub, conditions, _))) =>
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
-          Join(outerPlan, sub, LeftAnti, joinCond)
+          // Deduplicate conflicting attributes if any.
+          dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
         case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
           val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
           val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
-          Join(outerPlan, sub, LeftSemi, joinCond)
+          // Deduplicate conflicting attributes if any.
+          dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
         case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
           // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
           // Construct the condition. A NULL in one of the conditions is regarded as a positive
@@ -93,7 +123,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           // will have the final conditions in the LEFT ANTI as
           // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2)
           val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And)
-          Join(outerPlan, sub, LeftAnti, Option(pairs))
+          // Deduplicate conflicting attributes if any.
+          dedupJoin(Join(outerPlan, sub, LeftAnti, Option(pairs)))
         case (p, predicate) =>
           val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
           Project(p.output, Filter(newCond.get, inputPlan))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 274694b995..ee6905e999 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.test.SharedSQLContext
 
 class SubquerySuite extends QueryTest with SharedSQLContext {
@@ -875,4 +876,66 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
       assert(e.message.contains("cannot resolve '`a`' given input columns: [t.i, t.j]"))
     }
   }
+
+  test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 1") {
+    withTable("t1") {
+      withTempPath { path =>
+        Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath)
+        sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}'")
+
+        val sqlText =
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |NOT EXISTS (SELECT * FROM t1)
+          """.stripMargin
+        val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
+        val join = optimizedPlan.collectFirst { case j: Join => j }.get
+        assert(join.duplicateResolved)
+        assert(optimizedPlan.resolved)
+      }
+    }
+  }
+
+  test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 2") {
+    withTable("t1", "t2", "t3") {
+      withTempPath { path =>
+        val data = Seq((1, 1, 1), (2, 0, 2))
+
+        data.toDF("t1a", "t1b", "t1c").write.parquet(path.getCanonicalPath + "/t1")
+        data.toDF("t2a", "t2b", "t2c").write.parquet(path.getCanonicalPath + "/t2")
+        data.toDF("t3a", "t3b", "t3c").write.parquet(path.getCanonicalPath + "/t3")
+
+        sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}/t1'")
+        sql(s"CREATE TABLE t2 USING parquet LOCATION '${path.toURI}/t2'")
+        sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}/t3'")
+
+        val sqlText =
+          s"""
+             |SELECT *
+             |FROM (SELECT *
+             |      FROM t2
+             |      WHERE t2c IN (SELECT t1c
+             |                    FROM t1
+             |                    WHERE t1a = t2a)
+             |      UNION
+             |      SELECT *
+             |      FROM t3
+             |      WHERE t3a IN (SELECT t2a
+             |                    FROM t2
+             |                    UNION ALL
+             |                    SELECT t1a
+             |                    FROM t1
+             |                    WHERE t1b > 0)) t4
+             |WHERE t4.t2b IN (SELECT Min(t3b)
+             |                 FROM t3
+             |                 WHERE t4.t2a = t3a)
+           """.stripMargin
+        val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
+        val joinNodes = optimizedPlan.collect { case j: Join => j }
+        joinNodes.foreach(j => assert(j.duplicateResolved))
+        assert(optimizedPlan.resolved)
+      }
+    }
+  }
 }
--
GitLab
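
As a quick way to see the failure shape outside the test suite, here is a minimal sketch. The object name `Spark21835Repro`, the local-mode session, and the temp view stand-in for the parquet table are illustrative assumptions, not part of the patch; it runs case 1's self-referencing `NOT EXISTS` query and checks the invariant this patch restores:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.Join

object Spark21835Repro extends App {
  // Local session for illustration only; assumes a build with this patch applied.
  val spark = SparkSession.builder().master("local[1]").appName("spark-21835").getOrCreate()
  import spark.implicits._

  // Both sides of the rewritten LeftAnti join read the same relation, so without
  // dedupJoin they would expose identical attribute IDs and the optimized plan
  // would be unresolved (duplicateResolved == false).
  Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("t1")

  val optimized = spark.sql(
    "SELECT * FROM t1 WHERE NOT EXISTS (SELECT * FROM t1)"
  ).queryExecution.optimizedPlan

  // With the patch, the subquery side is wrapped in a Project of fresh aliases,
  // so the left and right output attribute sets no longer intersect.
  optimized.collect { case j: Join => j }.foreach { j =>
    assert(j.duplicateResolved, s"conflicting attributes in:\n$j")
  }
  assert(optimized.resolved)

  spark.stop()
}
```

Before this patch, the `LeftAnti` join produced by `RewritePredicateSubquery` carried the same attribute IDs on both sides, so `duplicateResolved` was false and the assertions above would fail.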
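
The NAAJ comment in the optimizer hunk compresses a subtle point: a NULL produced by the subquery makes `NOT IN` evaluate to unknown rather than false, which the `Or(c, IsNull(c))` condition preserves. A small sketch of that semantics, where the demo object and the tables `a`/`b` are hypothetical:

```scala
import org.apache.spark.sql.SparkSession

object NullAwareNotInDemo extends App {
  // Hypothetical standalone demo of the NULL-aware NOT IN semantics that the
  // NAAJ rewrite must preserve.
  val spark = SparkSession.builder().master("local[1]").appName("naaj-demo").getOrCreate()

  spark.sql("CREATE OR REPLACE TEMPORARY VIEW a AS SELECT * FROM VALUES (1), (2) AS a(a1)")
  spark.sql(
    "CREATE OR REPLACE TEMPORARY VIEW b AS SELECT * FROM VALUES (1), (CAST(NULL AS INT)) AS b(b1)")

  // For a1 = 2, `2 = NULL` evaluates to unknown, so `a1 NOT IN (SELECT b1 FROM b)`
  // is never TRUE once b contains a NULL: the query returns no rows.
  val rows = spark.sql("SELECT * FROM a WHERE a1 NOT IN (SELECT b1 FROM b)").collect()
  assert(rows.isEmpty)

  spark.stop()
}
```

This is why the rewritten anti-join condition takes the form `(A.A1 = B.B1 OR ISNULL(A.A1 = B.B1))`: the `ISNULL` disjunct makes every outer row match the NULL inner row, so the anti join correctly returns nothing.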