From 714811d0b5bcb5d47c39782ff74f898d276ecc59 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro <yamamuro@apache.org>
Date: Tue, 9 May 2017 20:22:51 +0800
Subject: [PATCH] [SPARK-20311][SQL] Support aliases for table value functions

## What changes were proposed in this pull request?
This PR adds parsing rules to support aliases for table-valued functions.
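
For example, with this change a table-valued function can be given a table alias and column aliases, mirroring the query in the updated `UnresolvedTableValuedFunction` doc comment and the added parser tests:

```sql
-- Table alias `t` with a column alias `a` for the function's output
SELECT t.a FROM range(10) AS t(a);

-- Only a table alias, no column aliases
SELECT * FROM range(2) AS t;
```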

## How was this patch tested?
Added tests in `AnalysisSuite` and `PlanParserSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #17666 from maropu/SPARK-20311.
---
 .../spark/sql/catalyst/parser/SqlBase.g4      | 20 ++++++++++++-----
 .../ResolveTableValuedFunctions.scala         | 22 ++++++++++++++++---
 .../sql/catalyst/analysis/unresolved.scala    | 10 +++++++--
 .../sql/catalyst/parser/AstBuilder.scala      | 17 ++++++++++----
 .../sql/catalyst/analysis/AnalysisSuite.scala | 14 +++++++++++-
 .../sql/catalyst/parser/PlanParserSuite.scala | 13 ++++++++++-
 6 files changed, 79 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 14c511f670..41daf58a98 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -472,15 +472,23 @@ identifierComment
     ;
 
 relationPrimary
-    : tableIdentifier sample? (AS? strictIdentifier)?               #tableName
-    | '(' queryNoWith ')' sample? (AS? strictIdentifier)?           #aliasedQuery
-    | '(' relation ')' sample? (AS? strictIdentifier)?              #aliasedRelation
-    | inlineTable                                                   #inlineTableDefault2
-    | identifier '(' (expression (',' expression)*)? ')'            #tableValuedFunction
+    : tableIdentifier sample? (AS? strictIdentifier)?      #tableName
+    | '(' queryNoWith ')' sample? (AS? strictIdentifier)?  #aliasedQuery
+    | '(' relation ')' sample? (AS? strictIdentifier)?     #aliasedRelation
+    | inlineTable                                          #inlineTableDefault2
+    | functionTable                                        #tableValuedFunction
     ;
 
 inlineTable
-    : VALUES expression (',' expression)*  (AS? identifier identifierList?)?
+    : VALUES expression (',' expression)*  tableAlias
+    ;
+
+functionTable
+    : identifier '(' (expression (',' expression)*)? ')' tableAlias
+    ;
+
+tableAlias
+    : (AS? identifier identifierList?)?
     ;
 
 rowFormat
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
index de6de24350..dad1340571 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis
 
 import java.util.Locale
 
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range}
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
 
@@ -105,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
-      builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
+      val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
         case Some(tvf) =>
           val resolved = tvf.flatMap { case (argList, resolver) =>
             argList.implicitCast(u.functionArgs) match {
@@ -125,5 +125,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
         case _ =>
           u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
       }
+
+      // If alias names are assigned, add a `Project` with the aliases
+      if (u.outputNames.nonEmpty) {
+        val outputAttrs = resolvedFunc.output
+        // Check that the number of aliases matches the number of output columns
+        if (u.outputNames.size != outputAttrs.size) {
+          u.failAnalysis(s"expected ${outputAttrs.size} columns but " +
+            s"found ${u.outputNames.size} columns")
+        }
+        val aliases = outputAttrs.zip(u.outputNames).map {
+          case (attr, name) => Alias(attr, name)()
+        }
+        Project(aliases, resolvedFunc)
+      } else {
+        resolvedFunc
+      }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 262b894e2a..51bef6e20b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -66,10 +66,16 @@ case class UnresolvedInlineTable(
 /**
  * A table-valued function, e.g.
  * {{{
- *   select * from range(10);
+ *   select id from range(10);
+ *
+ *   // Assign alias names
+ *   select t.a from range(10) t(a);
  * }}}
  */
-case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression])
+case class UnresolvedTableValuedFunction(
+    functionName: String,
+    functionArgs: Seq[Expression],
+    outputNames: Seq[String])
   extends LeafNode {
 
   override def output: Seq[Attribute] = Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index d2a9b4a9a9..e03fe2ccb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -687,7 +687,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
    */
   override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
       : LogicalPlan = withOrigin(ctx) {
-    UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression))
+    val func = ctx.functionTable
+    val aliases = if (func.tableAlias.identifierList != null) {
+      visitIdentifierList(func.tableAlias.identifierList)
+    } else {
+      Seq.empty
+    }
+
+    val tvf = UnresolvedTableValuedFunction(
+      func.identifier.getText, func.expression.asScala.map(expression), aliases)
+    tvf.optionalMap(func.tableAlias.identifier)(aliasPlan)
   }
 
   /**
@@ -705,14 +714,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
       }
     }
 
-    val aliases = if (ctx.identifierList != null) {
-      visitIdentifierList(ctx.identifierList)
+    val aliases = if (ctx.tableAlias.identifierList != null) {
+      visitIdentifierList(ctx.tableAlias.identifierList)
     } else {
       Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
     }
 
     val table = UnresolvedInlineTable(aliases, rows)
-    table.optionalMap(ctx.identifier)(aliasPlan)
+    table.optionalMap(ctx.tableAlias.identifier)(aliasPlan)
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 893bb1b74c..31047f6886 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.Cross
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
@@ -441,4 +440,17 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
 
     checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation)
   }
+
+  test("SPARK-20311 range(N) as alias") {
+    def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = {
+      SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames))
+        .select(star())
+    }
+    assertAnalysisSuccess(rangeWithAliases(3 :: Nil, "a" :: Nil))
+    assertAnalysisSuccess(rangeWithAliases(1 :: 4 :: Nil, "b" :: Nil))
+    assertAnalysisSuccess(rangeWithAliases(2 :: 6 :: 2 :: Nil, "c" :: Nil))
+    assertAnalysisError(
+      rangeWithAliases(3 :: Nil, "a" :: "b" :: Nil),
+      Seq("expected 1 columns but found 2 columns"))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 411777d6e8..4c2476296c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -468,7 +468,18 @@ class PlanParserSuite extends PlanTest {
   test("table valued function") {
     assertEqual(
       "select * from range(2)",
-      UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star()))
+      UnresolvedTableValuedFunction("range", Literal(2) :: Nil, Seq.empty).select(star()))
+  }
+
+  test("SPARK-20311 range(N) as alias") {
+    assertEqual(
+      "select * from range(10) AS t",
+      SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(10) :: Nil, Seq.empty))
+        .select(star()))
+    assertEqual(
+      "select * from range(7) AS t(a)",
+      SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, "a" :: Nil))
+        .select(star()))
   }
 
   test("inline table") {
-- 
GitLab