From 7b65030e7a0af3a0bd09370fb069d659b36ff7f0 Mon Sep 17 00:00:00 2001
From: Sean Zhong <seanzhong@databricks.com>
Date: Tue, 16 Aug 2016 15:51:30 +0800
Subject: [PATCH] [SPARK-17034][SQL] adds expression UnresolvedOrdinal to
 represent the ordinals in GROUP BY or ORDER BY

## What changes were proposed in this pull request?

This PR adds expression `UnresolvedOrdinal` to represent the ordinal in GROUP BY or ORDER BY, and fixes the rules when resolving ordinals.

Ordinals in GROUP BY or ORDER BY like `1` in `order by 1` or `group by 1` should be considered as unresolved before analysis. But in current code, it uses `Literal` expression to store the ordinal. This is inappropriate as `Literal` itself is a resolved expression, it gives the user a wrong message that the ordinals has already been resolved.

### Before this change

Ordinal is stored as `Literal` expression

```
scala> sc.setLogLevel("TRACE")
scala> sql("select a from t group by 1 order by 1")
...
'Sort [1 ASC], true
 +- 'Aggregate [1], ['a]
     +- 'UnresolvedRelation `t
```

For query:

```
scala> Seq(1).toDF("a").createOrReplaceTempView("t")
scala> sql("select count(a), a from t group by 2 having a > 0").show
```

During analysis, the intermediate plan before applying rule `ResolveAggregateFunctions` is:

```
'Filter ('a > 0)
   +- Aggregate [2], [count(1) AS count(1)#83L, a#81]
        +- LocalRelation [value#7 AS a#9]
```

Before this PR, rule `ResolveAggregateFunctions` believes all expressions of `Aggregate` have already been resolved, and tries to resolve the expressions in `Filter` directly. But this is wrong, as ordinal `2` in Aggregate is not really resolved!

### After this change

Ordinals are stored as `UnresolvedOrdinal`.

```
scala> sc.setLogLevel("TRACE")
scala> sql("select a from t group by 1 order by 1")
...
'Sort [unresolvedordinal(1) ASC], true
 +- 'Aggregate [unresolvedordinal(1)], ['a]
      +- 'UnresolvedRelation `t`
```

## How was this patch tested?

Unit tests.

Author: Sean Zhong <seanzhong@databricks.com>

Closes #14616 from clockfly/spark-16955.
---
 .../sql/catalyst/analysis/Analyzer.scala      | 23 ++++---
 .../UnresolvedOrdinalSubstitution.scala       | 52 +++++++++++++++
 .../sql/catalyst/analysis/unresolved.scala    | 18 +++++
 .../sql/catalyst/analysis/AnalysisSuite.scala |  2 +-
 .../UnresolvedOrdinalSubstitutionSuite.scala  | 65 +++++++++++++++++++
 .../sql-tests/inputs/group-by-ordinal.sql     |  6 ++
 .../results/group-by-ordinal.sql.out          | 28 ++++++--
 7 files changed, 175 insertions(+), 19 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a2e276e8a2..a2a022c247 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -22,17 +22,16 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
-import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogRelation, InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
 import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
-import org.apache.spark.sql.catalyst.planning.IntegerIndex
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.catalyst.trees.TreeNodeRef
+import org.apache.spark.sql.catalyst.trees.{TreeNodeRef}
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.types._
 
@@ -84,7 +83,8 @@ class Analyzer(
     Batch("Substitution", fixedPoint,
       CTESubstitution,
       WindowsSubstitution,
-      EliminateUnions),
+      EliminateUnions,
+      new UnresolvedOrdinalSubstitution(conf)),
     Batch("Resolution", fixedPoint,
       ResolveRelations ::
       ResolveReferences ::
@@ -545,7 +545,7 @@ class Analyzer(
         p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
       // If the aggregate function argument contains Stars, expand it.
       case a: Aggregate if containsStar(a.aggregateExpressions) =>
-        if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
+        if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) {
           failAnalysis(
             "Star (*) is not allowed in select list when GROUP BY ordinal position is used")
         } else {
@@ -716,9 +716,9 @@ class Analyzer(
       // Replace the index with the related attribute for ORDER BY,
       // which is a 1-base position of the projection list.
       case s @ Sort(orders, global, child)
-          if conf.orderByOrdinal && orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) =>
+        if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
         val newOrders = orders map {
-          case s @ SortOrder(IntegerIndex(index), direction) =>
+          case s @ SortOrder(UnresolvedOrdinal(index), direction) =>
             if (index > 0 && index <= child.output.size) {
               SortOrder(child.output(index - 1), direction)
             } else {
@@ -732,11 +732,10 @@ class Analyzer(
 
       // Replace the index with the corresponding expression in aggregateExpressions. The index is
       // a 1-base position of aggregateExpressions, which is output columns (select expression)
-      case a @ Aggregate(groups, aggs, child)
-          if conf.groupByOrdinal && aggs.forall(_.resolved) &&
-            groups.exists(IntegerIndex.unapply(_).nonEmpty) =>
+      case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
+        groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
         val newGroups = groups.map {
-          case ordinal @ IntegerIndex(index) if index > 0 && index <= aggs.size =>
+          case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
             aggs(index - 1) match {
               case e if ResolveAggregateFunctions.containsAggregate(e) =>
                 ordinal.failAnalysis(
@@ -744,7 +743,7 @@ class Analyzer(
                     "aggregate functions are not allowed in GROUP BY")
               case o => o
             }
-          case ordinal @ IntegerIndex(index) =>
+          case ordinal @ UnresolvedOrdinal(index) =>
             ordinal.failAnalysis(
               s"GROUP BY position $index is not in select list " +
                 s"(valid range is [1, ${aggs.size}])")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala
new file mode 100644
index 0000000000..e21cd08af8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
+import org.apache.spark.sql.catalyst.planning.IntegerIndex
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
+
+/**
+ * Replaces ordinal in 'order by' or 'group by' with UnresolvedOrdinal expression.
+ */
+class UnresolvedOrdinalSubstitution(conf: CatalystConf) extends Rule[LogicalPlan] {
+  private def isIntegerLiteral(sorter: Expression) = IntegerIndex.unapply(sorter).nonEmpty
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case s @ Sort(orders, global, child) if conf.orderByOrdinal &&
+      orders.exists(o => isIntegerLiteral(o.child)) =>
+      val newOrders = orders.map {
+        case order @ SortOrder(ordinal @ IntegerIndex(index: Int), _) =>
+          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
+          withOrigin(order.origin)(order.copy(child = newOrdinal))
+        case other => other
+      }
+      withOrigin(s.origin)(s.copy(order = newOrders))
+    case a @ Aggregate(groups, aggs, child) if conf.groupByOrdinal &&
+      groups.exists(isIntegerLiteral(_)) =>
+      val newGroups = groups.map {
+        case ordinal @ IntegerIndex(index) =>
+          withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
+        case other => other
+      }
+      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 609089a302..42e7aae0b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -370,3 +370,21 @@ case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpr
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override lazy val resolved = false
 }
+
+/**
+ * Represents unresolved ordinal used in order by or group by.
+ *
+ * For example:
+ * {{{
+ *   select a from table order by 1
+ *   select a   from table group by 1
+ * }}}
+ * @param ordinal ordinal starts from 1, instead of 0
+ */
+case class UnresolvedOrdinal(ordinal: Int)
+    extends LeafExpression with Unevaluable with NonSQLExpression {
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+  override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+  override lazy val resolved = false
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 102c78bd72..22e1c9be05 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala
new file mode 100644
index 0000000000..23995e96e1
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+
+class UnresolvedOrdinalSubstitutionSuite extends AnalysisTest {
+
+  test("test rule UnresolvedOrdinalSubstitution, replaces ordinal in order by or group by") {
+    val a = testRelation2.output(0)
+    val b = testRelation2.output(1)
+    val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true)
+
+    // Expression OrderByOrdinal is unresolved.
+    assert(!UnresolvedOrdinal(0).resolved)
+
+    // Tests order by ordinal, apply single rule.
+    val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc)
+    comparePlans(
+      new UnresolvedOrdinalSubstitution(conf).apply(plan),
+      testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc))
+
+    // Tests order by ordinal, do full analysis
+    checkAnalysis(plan, testRelation2.orderBy(a.asc, b.asc))
+
+    // order by ordinal can be turned off by config
+    comparePlans(
+      new UnresolvedOrdinalSubstitution(conf.copy(orderByOrdinal = false)).apply(plan),
+      testRelation2.orderBy(Literal(1).asc, Literal(2).asc))
+
+
+    // Tests group by ordinal, apply single rule.
+    val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)
+    comparePlans(
+      new UnresolvedOrdinalSubstitution(conf).apply(plan2),
+      testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b))
+
+    // Tests group by ordinal, do full analysis
+    checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b))
+
+    // group by ordinal can be turned off by config
+    comparePlans(
+      new UnresolvedOrdinalSubstitution(conf.copy(groupByOrdinal = false)).apply(plan2),
+      testRelation2.groupBy(Literal(1), Literal(2))('a, 'b))
+  }
+}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql
index 36b469c617..9c8d851e36 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql
@@ -43,6 +43,12 @@ select a, rand(0), sum(b) from data group by a, 2;
 -- negative case: star
 select * from data group by a, b, 1;
 
+-- group by ordinal followed by order by
+select a, count(a) from (select 1 as a) tmp group by 1 order by 1;
+
+-- group by ordinal followed by having
+select count(a), a from (select 1 as a) tmp group by 2 having a > 0;
+
 -- turn of group by ordinal
 set spark.sql.groupByOrdinal=false;
 
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
index 2f10b7ebc6..9c3a145f3a 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 17
+-- Number of queries: 19
 
 
 -- !query 0
@@ -153,16 +153,32 @@ Star (*) is not allowed in select list when GROUP BY ordinal position is used;
 
 
 -- !query 15
-set spark.sql.groupByOrdinal=false
+select a, count(a) from (select 1 as a) tmp group by 1 order by 1
 -- !query 15 schema
-struct<key:string,value:string>
+struct<a:int,count(a):bigint>
 -- !query 15 output
-spark.sql.groupByOrdinal
+1	1
 
 
 -- !query 16
-select sum(b) from data group by -1
+select count(a), a from (select 1 as a) tmp group by 2 having a > 0
 -- !query 16 schema
-struct<sum(b):bigint>
+struct<count(a):bigint,a:int>
 -- !query 16 output
+1	1
+
+
+-- !query 17
+set spark.sql.groupByOrdinal=false
+-- !query 17 schema
+struct<key:string,value:string>
+-- !query 17 output
+spark.sql.groupByOrdinal
+
+
+-- !query 18
+select sum(b) from data group by -1
+-- !query 18 schema
+struct<sum(b):bigint>
+-- !query 18 output
 9
-- 
GitLab