diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 0447436ea79765b2299c22f9bcf3142f4b89d997..9a643465a9994429ea69c47138d40431ea6fa161 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -374,11 +374,12 @@ setQuantifier ; relation - : left=relation - (joinType JOIN right=relation joinCriteria? - | NATURAL joinType JOIN right=relation - ) #joinRelation - | relationPrimary #relationDefault + : relationPrimary joinRelation* + ; + +joinRelation + : (joinType) JOIN right=relationPrimary joinCriteria? + | NATURAL joinType JOIN right=relationPrimary ; joinType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e4cb9f016133abf3ac3693dda6a0ca61b00776c8..bbbb14df88f8cc24c6d9204e205acca6de2f7789 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -92,10 +92,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Apply CTEs query.optional(ctx.ctes) { - val ctes = ctx.ctes.namedQuery.asScala.map { - case nCtx => - val namedQuery = visitNamedQuery(nCtx) - (namedQuery.alias, namedQuery) + val ctes = ctx.ctes.namedQuery.asScala.map { nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) } // Check for duplicate names. checkDuplicateKeys(ctes, ctx) @@ -401,7 +400,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * separated) relations here, these get converted into a single plan by condition-less inner join. */ override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { - val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None)) + val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) => + val right = plan(relation.relationPrimary) + val join = right.optionalMap(left)(Join(_, _, Inner, None)) + withJoinRelations(join, relation) + } ctx.lateralView.asScala.foldLeft(from)(withGenerate) } @@ -532,55 +535,53 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a joins between two or more logical plans. + * Create a single relation referenced in a FROM claused. This method is used when a part of the + * join condition is nested, for example: + * {{{ + * select * from t1 join (t2 cross join t3) on col1 = col2 + * }}} */ - override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) { - /** Build a join between two plans. */ - def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = { - val baseJoinType = ctx.joinType match { - case null => Inner - case jt if jt.CROSS != null => Cross - case jt if jt.FULL != null => FullOuter - case jt if jt.SEMI != null => LeftSemi - case jt if jt.ANTI != null => LeftAnti - case jt if jt.LEFT != null => LeftOuter - case jt if jt.RIGHT != null => RightOuter - case _ => Inner - } + override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) { + withJoinRelations(plan(ctx.relationPrimary), ctx) + } - // Resolve the join type and join condition - val (joinType, condition) = Option(ctx.joinCriteria) match { - case Some(c) if c.USING != null => - val columns = c.identifier.asScala.map { column => - UnresolvedAttribute.quoted(column.getText) - } - (UsingJoin(baseJoinType, columns), None) - case Some(c) if c.booleanExpression != null => - (baseJoinType, Option(expression(c.booleanExpression))) - case None if ctx.NATURAL != null => - (NaturalJoin(baseJoinType), None) - case None => - (baseJoinType, None) - } - Join(left, right, joinType, condition) - } + /** + * Join one more [[LogicalPlan]]s to the current logical plan. + */ + private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = { + ctx.joinRelation.asScala.foldLeft(base) { (left, join) => + withOrigin(join) { + val baseJoinType = join.joinType match { + case null => Inner + case jt if jt.CROSS != null => Cross + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } - // Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the - // first join clause is at the top. However fields of previously referenced tables can be used - // in following join clauses. The tree needs to be reversed in order to make this work. - var result = plan(ctx.left) - var current = ctx - while (current != null) { - current.right match { - case right: JoinRelationContext => - result = join(current, result, plan(right.left)) - current = right - case right => - result = join(current, result, plan(right)) - current = null + // Resolve the join type and join condition + val (joinType, condition) = Option(join.joinCriteria) match { + case Some(c) if c.USING != null => + val columns = c.identifier.asScala.map { column => + UnresolvedAttribute.quoted(column.getText) + } + (UsingJoin(baseJoinType, columns), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case None if join.NATURAL != null => + if (baseJoinType == Cross) { + throw new ParseException("NATURAL CROSS JOIN is not supported", ctx) + } + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + Join(left, plan(join.right), joinType, condition) } } - result } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index cb89a9679a8cf3ec919253639a2030506c97d1a5..6fbc33fad735c935430666586e5184267c12a1d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser import scala.collection.mutable.StringBuilder -import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} +import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.misc.Interval import org.antlr.v4.runtime.tree.TerminalNode @@ -189,9 +189,7 @@ object ParserUtils { * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the * passed function. The original plan is returned when the context does not exist. */ - def optionalMap[C <: ParserRuleContext]( - ctx: C)( - f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { + def optionalMap[C](ctx: C)(f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { if (ctx != null) { f(ctx, plan) } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index faaea17b64d2a5ddf52522fe8396393bd55d2770..ca86304d4d400b44f52e1ba92ec8387ee60c7fd0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -360,10 +360,54 @@ class PlanParserSuite extends PlanTest { test("left anti join", LeftAnti, testExistence) test("anti join", LeftAnti, testExistence) + // Test natural cross join + intercept("select * from a natural cross join b") + + // Test natural join with a condition + intercept("select * from a natural join b on a.id = b.id") + // Test multiple consecutive joins assertEqual( "select * from a join b join c right join d", table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) + + // SPARK-17296 + assertEqual( + "select * from t1 cross join t2 join t3 on t3.id = t1.id join t4 on t4.id = t1.id", + table("t1") + .join(table("t2"), Cross) + .join(table("t3"), Inner, Option(Symbol("t3.id") === Symbol("t1.id"))) + .join(table("t4"), Inner, Option(Symbol("t4.id") === Symbol("t1.id"))) + .select(star())) + + // Test multiple on clauses. + intercept("select * from t1 inner join t2 inner join t3 on col3 = col2 on col3 = col1") + + // Parenthesis + assertEqual( + "select * from t1 inner join (t2 inner join t3 on col3 = col2) on col3 = col1", + table("t1") + .join(table("t2") + .join(table("t3"), Inner, Option('col3 === 'col2)), Inner, Option('col3 === 'col1)) + .select(star())) + assertEqual( + "select * from t1 inner join (t2 inner join t3) on col3 = col2", + table("t1") + .join(table("t2").join(table("t3"), Inner, None), Inner, Option('col3 === 'col2)) + .select(star())) + assertEqual( + "select * from t1 inner join (t2 inner join t3 on col3 = col2)", + table("t1") + .join(table("t2").join(table("t3"), Inner, Option('col3 === 'col2)), Inner, None) + .select(star())) + + // Implicit joins. + assertEqual( + "select * from t1, t3 join t2 on t1.col1 = t2.col2", + table("t1") + .join(table("t3")) + .join(table("t2"), Inner, Option(Symbol("t1.col1") === Symbol("t2.col2"))) + .select(star())) } test("sampled relations") {