From 4f769b903bc9822c262f0a15f5933cc05c67923f Mon Sep 17 00:00:00 2001
From: Herman van Hovell <hvanhovell@databricks.com>
Date: Wed, 7 Sep 2016 00:44:07 +0200
Subject: [PATCH] [SPARK-17296][SQL] Simplify parser join processing.

## What changes were proposed in this pull request?
Join processing in the parser relies on the fact that the grammar produces right-nested trees; for instance, the parse tree for `select * from a join b join c` is expected to produce a tree similar to `JOIN(a, JOIN(b, c))`. However, there are cases in which this invariant is violated, for example:
```sql
SELECT COUNT(1)
FROM test T1
     CROSS JOIN test T2
     JOIN test T3
      ON T3.col = T1.col
     JOIN test T4
      ON T4.col = T1.col
```
In this case the parser returns a tree in which joins are located on both the left and the right side of the parent join node.
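For illustration only (the exact parse tree depends on how ANTLR matches the optional join criteria), the two shapes look roughly like this in the `JOIN(left, right)` notation used above:
```
-- Shape the old join processing assumes (right nested):
JOIN(T1, JOIN(T2, JOIN(T3, T4)))

-- Shape described above, with joins on both sides of a parent node (illustrative):
JOIN(JOIN(T1, T2), JOIN(T3, T4))
```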

This PR introduces a different grammar rule which does not make this assumption. The new rule matches a primary relation followed by zero or more joined relations. As a bonus, join processing becomes much simpler.
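As a rough sketch of the new strategy (not the patch code; all names below are illustrative), the builder can fold the flat list of join clauses onto the primary relation, producing a left-deep join tree in which each clause can reference previously joined tables:
```scala
// Minimal, self-contained sketch of folding a primary relation plus a flat list of
// join clauses into a left-deep join tree. Names are illustrative, not Spark's.
sealed trait Plan
case class Table(name: String) extends Plan
case class Join(left: Plan, right: Plan, condition: Option[String]) extends Plan

// One entry per `joinRelation` matched by the new grammar rule.
case class JoinClause(right: Plan, condition: Option[String])

def withJoins(base: Plan, joins: Seq[JoinClause]): Plan =
  joins.foldLeft(base) { (left, j) => Join(left, j.right, j.condition) }

// select * from a join b join c  =>  Join(Join(Table(a), Table(b)), Table(c))
val example = withJoins(Table("a"),
  Seq(JoinClause(Table("b"), None), JoinClause(Table("c"), None)))
```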

## How was this patch tested?
Existing tests; I have also added a regression test to the plan parser suite.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #14867 from hvanhovell/SPARK-17296.
---
 .../spark/sql/catalyst/parser/SqlBase.g4      | 11 ++-
 .../sql/catalyst/parser/AstBuilder.scala      | 99 ++++++++++---------
 .../sql/catalyst/parser/ParserUtils.scala     |  6 +-
 .../sql/catalyst/parser/PlanParserSuite.scala | 44 +++++++++
 4 files changed, 102 insertions(+), 58 deletions(-)

diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 0447436ea7..9a643465a9 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -374,11 +374,12 @@ setQuantifier
     ;
 
 relation
-    : left=relation
-      (joinType JOIN right=relation joinCriteria?
-      | NATURAL joinType JOIN right=relation
-      )                                           #joinRelation
-    | relationPrimary                             #relationDefault
+    : relationPrimary joinRelation*
+    ;
+
+joinRelation
+    : (joinType) JOIN right=relationPrimary joinCriteria?
+    | NATURAL joinType JOIN right=relationPrimary
     ;
 
 joinType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index e4cb9f0161..bbbb14df88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -92,10 +92,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
 
     // Apply CTEs
     query.optional(ctx.ctes) {
-      val ctes = ctx.ctes.namedQuery.asScala.map {
-        case nCtx =>
-          val namedQuery = visitNamedQuery(nCtx)
-          (namedQuery.alias, namedQuery)
+      val ctes = ctx.ctes.namedQuery.asScala.map { nCtx =>
+        val namedQuery = visitNamedQuery(nCtx)
+        (namedQuery.alias, namedQuery)
       }
       // Check for duplicate names.
       checkDuplicateKeys(ctes, ctx)
@@ -401,7 +400,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
    * separated) relations here, these get converted into a single plan by condition-less inner join.
    */
   override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
-    val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None))
+    val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
+      val right = plan(relation.relationPrimary)
+      val join = right.optionalMap(left)(Join(_, _, Inner, None))
+      withJoinRelations(join, relation)
+    }
     ctx.lateralView.asScala.foldLeft(from)(withGenerate)
   }
 
@@ -532,55 +535,53 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
   }
 
   /**
-   * Create a joins between two or more logical plans.
+   * Create a single relation referenced in a FROM clause. This method is used when part of the
+   * join is nested (parenthesized), for example:
+   * {{{
+   *   select * from t1 join (t2 cross join t3) on col1 = col2
+   * }}}
    */
-  override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) {
-    /** Build a join between two plans. */
-    def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = {
-      val baseJoinType = ctx.joinType match {
-        case null => Inner
-        case jt if jt.CROSS != null => Cross
-        case jt if jt.FULL != null => FullOuter
-        case jt if jt.SEMI != null => LeftSemi
-        case jt if jt.ANTI != null => LeftAnti
-        case jt if jt.LEFT != null => LeftOuter
-        case jt if jt.RIGHT != null => RightOuter
-        case _ => Inner
-      }
+  override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) {
+    withJoinRelations(plan(ctx.relationPrimary), ctx)
+  }
 
-      // Resolve the join type and join condition
-      val (joinType, condition) = Option(ctx.joinCriteria) match {
-        case Some(c) if c.USING != null =>
-          val columns = c.identifier.asScala.map { column =>
-            UnresolvedAttribute.quoted(column.getText)
-          }
-          (UsingJoin(baseJoinType, columns), None)
-        case Some(c) if c.booleanExpression != null =>
-          (baseJoinType, Option(expression(c.booleanExpression)))
-        case None if ctx.NATURAL != null =>
-          (NaturalJoin(baseJoinType), None)
-        case None =>
-          (baseJoinType, None)
-      }
-      Join(left, right, joinType, condition)
-    }
+  /**
+   * Join one or more [[LogicalPlan]]s to the current logical plan.
+   */
+  private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
+    ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
+      withOrigin(join) {
+        val baseJoinType = join.joinType match {
+          case null => Inner
+          case jt if jt.CROSS != null => Cross
+          case jt if jt.FULL != null => FullOuter
+          case jt if jt.SEMI != null => LeftSemi
+          case jt if jt.ANTI != null => LeftAnti
+          case jt if jt.LEFT != null => LeftOuter
+          case jt if jt.RIGHT != null => RightOuter
+          case _ => Inner
+        }
 
-    // Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the
-    // first join clause is at the top. However fields of previously referenced tables can be used
-    // in following join clauses. The tree needs to be reversed in order to make this work.
-    var result = plan(ctx.left)
-    var current = ctx
-    while (current != null) {
-      current.right match {
-        case right: JoinRelationContext =>
-          result = join(current, result, plan(right.left))
-          current = right
-        case right =>
-          result = join(current, result, plan(right))
-          current = null
+        // Resolve the join type and join condition
+        val (joinType, condition) = Option(join.joinCriteria) match {
+          case Some(c) if c.USING != null =>
+            val columns = c.identifier.asScala.map { column =>
+              UnresolvedAttribute.quoted(column.getText)
+            }
+            (UsingJoin(baseJoinType, columns), None)
+          case Some(c) if c.booleanExpression != null =>
+            (baseJoinType, Option(expression(c.booleanExpression)))
+          case None if join.NATURAL != null =>
+            if (baseJoinType == Cross) {
+              throw new ParseException("NATURAL CROSS JOIN is not supported", ctx)
+            }
+            (NaturalJoin(baseJoinType), None)
+          case None =>
+            (baseJoinType, None)
+        }
+        Join(left, plan(join.right), joinType, condition)
       }
     }
-    result
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index cb89a9679a..6fbc33fad7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser
 
 import scala.collection.mutable.StringBuilder
 
-import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token}
+import org.antlr.v4.runtime.{ParserRuleContext, Token}
 import org.antlr.v4.runtime.misc.Interval
 import org.antlr.v4.runtime.tree.TerminalNode
 
@@ -189,9 +189,7 @@ object ParserUtils {
      * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the
      * passed function. The original plan is returned when the context does not exist.
      */
-    def optionalMap[C <: ParserRuleContext](
-        ctx: C)(
-        f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
+    def optionalMap[C](ctx: C)(f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
       if (ctx != null) {
         f(ctx, plan)
       } else {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index faaea17b64..ca86304d4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -360,10 +360,54 @@ class PlanParserSuite extends PlanTest {
     test("left anti join", LeftAnti, testExistence)
     test("anti join", LeftAnti, testExistence)
 
+    // Test natural cross join
+    intercept("select * from a natural cross join b")
+
+    // Test natural join with a condition
+    intercept("select * from a natural join b on a.id = b.id")
+
     // Test multiple consecutive joins
     assertEqual(
       "select * from a join b join c right join d",
       table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star()))
+
+    // SPARK-17296
+    assertEqual(
+      "select * from t1 cross join t2 join t3 on t3.id = t1.id join t4 on t4.id = t1.id",
+      table("t1")
+        .join(table("t2"), Cross)
+        .join(table("t3"), Inner, Option(Symbol("t3.id") === Symbol("t1.id")))
+        .join(table("t4"), Inner, Option(Symbol("t4.id") === Symbol("t1.id")))
+        .select(star()))
+
+    // Test multiple on clauses.
+    intercept("select * from t1 inner join t2 inner join t3 on col3 = col2 on col3 = col1")
+
+    // Parenthesis
+    assertEqual(
+      "select * from t1 inner join (t2 inner join t3 on col3 = col2) on col3 = col1",
+      table("t1")
+        .join(table("t2")
+          .join(table("t3"), Inner, Option('col3 === 'col2)), Inner, Option('col3 === 'col1))
+        .select(star()))
+    assertEqual(
+      "select * from t1 inner join (t2 inner join t3) on col3 = col2",
+      table("t1")
+        .join(table("t2").join(table("t3"), Inner, None), Inner, Option('col3 === 'col2))
+        .select(star()))
+    assertEqual(
+      "select * from t1 inner join (t2 inner join t3 on col3 = col2)",
+      table("t1")
+        .join(table("t2").join(table("t3"), Inner, Option('col3 === 'col2)), Inner, None)
+        .select(star()))
+
+    // Implicit joins.
+    assertEqual(
+      "select * from t1, t3 join t2 on t1.col1 = t2.col2",
+      table("t1")
+        .join(table("t3"))
+        .join(table("t2"), Inner, Option(Symbol("t1.col1") === Symbol("t2.col2")))
+        .select(star()))
   }
 
   test("sampled relations") {
-- 
GitLab