diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 7023c0c8c493f0024cb559432161f835c83acca8..de2f9ee6bc7a2b83e122519fb200efbe8269a7dd 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -262,7 +262,7 @@ ctes ; namedQuery - : name=identifier AS? '(' queryNoWith ')' + : name=identifier AS? '(' query ')' ; tableProvider diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cc62d5e7c8826f04dc8327a78aeb449ff9e1e898..ae8869ff25f2d2792e0d1f97e475d369c737326c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -116,15 +116,14 @@ class Analyzer( ) /** - * Substitute child plan with cte definitions + * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - // TODO allow subquery to define CTE def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => - resolved :+ name -> ResolveRelations(substituteCTE(relation, resolved)) + resolved :+ name -> execute(substituteCTE(relation, resolved)) }) case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 69d68fa6f92efa54f19e6068ad5446d4e8cfa9c8..12a70b7769ef616b484878047ef79b612a91b43c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -108,7 +108,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * This is only used for Common Table Expressions. */ override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { - SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith), None) + SubqueryAlias(ctx.name.getText, plan(ctx.query), None) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 52387b4b72a16b6eb645704ef67c016ef1bad2de..eab45050f7e6335ffe73fe545f46078f1cc71f1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -76,6 +76,31 @@ class SubquerySuite extends QueryTest with SharedSQLContext { ) } + test("define CTE in CTE subquery") { + checkAnswer( + sql( + """ + | with t2 as (with t1 as (select 1 as b, 2 as c) select b, c from t1) + | select a from (select 1 as a union all select 2 as a) t + | where a = (select max(b) from t2) + """.stripMargin), + Array(Row(1)) + ) + checkAnswer( + sql( + """ + | with t2 as (with t1 as (select 1 as b, 2 as c) select b, c from t1), + | t3 as ( + | with t4 as (select 1 as d, 3 as e) + | select * from t4 cross join t2 where t2.b = t4.d + | ) + | select a from (select 1 as a union all select 2 as a) + | where a = (select max(d) from t3) + """.stripMargin), + Array(Row(1)) + ) + } + test("uncorrelated scalar subquery in CTE") { checkAnswer( sql("with t2 as (select 1 as b, 2 as c) " +