diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a2e276e8a20596ff480b07c7bb7b3d170909f8f1..a2a022c2476fb3bd6cfae396301ed006df82fb00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -22,17 +22,16 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogRelation, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.NewInstance import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification -import org.apache.spark.sql.catalyst.planning.IntegerIndex import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.trees.{TreeNodeRef} import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types._ @@ -84,7 +83,8 @@ class Analyzer( Batch("Substitution", fixedPoint, CTESubstitution, WindowsSubstitution, - EliminateUnions), + EliminateUnions, + new UnresolvedOrdinalSubstitution(conf)), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: @@ -545,7 +545,7 @@ class Analyzer( p.copy(projectList = buildExpandedProjectList(p.projectList, p.child)) // If the aggregate function argument 
contains Stars, expand it. case a: Aggregate if containsStar(a.aggregateExpressions) => - if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) { + if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) { failAnalysis( "Star (*) is not allowed in select list when GROUP BY ordinal position is used") } else { @@ -716,9 +716,9 @@ class Analyzer( // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. case s @ Sort(orders, global, child) - if conf.orderByOrdinal && orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) => + if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { - case s @ SortOrder(IntegerIndex(index), direction) => + case s @ SortOrder(UnresolvedOrdinal(index), direction) => if (index > 0 && index <= child.output.size) { SortOrder(child.output(index - 1), direction) } else { @@ -732,11 +732,10 @@ class Analyzer( // Replace the index with the corresponding expression in aggregateExpressions. 
The index is // a 1-base position of aggregateExpressions, which is output columns (select expression) - case a @ Aggregate(groups, aggs, child) - if conf.groupByOrdinal && aggs.forall(_.resolved) && - groups.exists(IntegerIndex.unapply(_).nonEmpty) => + case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && + groups.exists(_.isInstanceOf[UnresolvedOrdinal]) => val newGroups = groups.map { - case ordinal @ IntegerIndex(index) if index > 0 && index <= aggs.size => + case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => aggs(index - 1) match { case e if ResolveAggregateFunctions.containsAggregate(e) => ordinal.failAnalysis( @@ -744,7 +743,7 @@ class Analyzer( "aggregate functions are not allowed in GROUP BY") case o => o } - case ordinal @ IntegerIndex(index) => + case ordinal @ UnresolvedOrdinal(index) => ordinal.failAnalysis( s"GROUP BY position $index is not in select list " + s"(valid range is [1, ${aggs.size}])") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala new file mode 100644 index 0000000000000000000000000000000000000000..e21cd08af8b0d9dafb6940eb8c3813dcea3b5501 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} +import org.apache.spark.sql.catalyst.planning.IntegerIndex +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin + +/** + * Replaces ordinals in 'order by' / 'group by' with [[UnresolvedOrdinal]] when enabled in conf. + */ +class UnresolvedOrdinalSubstitution(conf: CatalystConf) extends Rule[LogicalPlan] { + private def isIntegerLiteral(sorter: Expression) = IntegerIndex.unapply(sorter).nonEmpty + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case s @ Sort(orders, global, child) if conf.orderByOrdinal && + orders.exists(o => isIntegerLiteral(o.child)) => + val newOrders = orders.map { + case order @ SortOrder(ordinal @ IntegerIndex(index: Int), _) => + val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) + withOrigin(order.origin)(order.copy(child = newOrdinal)) + case other => other + } + withOrigin(s.origin)(s.copy(order = newOrders)) + case a @ Aggregate(groups, aggs, child) if conf.groupByOrdinal && + groups.exists(isIntegerLiteral(_)) => + val newGroups = groups.map { + case ordinal @ IntegerIndex(index) => + withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) + case other => other + } + withOrigin(a.origin)(a.copy(groupingExpressions = newGroups)) + } +} diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 609089a302c888474ecf014bf9eda73d2d2ba1c8..42e7aae0b6b056949a6649a76ee1ba66b2b7fd7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -370,3 +370,21 @@ case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpr override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false } + +/** + * Represents an unresolved ordinal used in an ORDER BY or GROUP BY clause. + * + * For example: + * {{{ + * select a from table order by 1 + * select a from table group by 1 + * }}} + * @param ordinal the 1-based position; ordinals start from 1, not 0 + */ +case class UnresolvedOrdinal(ordinal: Int) + extends LeafExpression with Unevaluable with NonSQLExpression { + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 102c78bd72111887e55954a9f91ceeccd514cd01..22e1c9be0573d9fa4e8b5fa5d1b43238d7f7c1d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import 
org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..23995e96e1d2bcee009c431d329948c5cb2eb763 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2 +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.SimpleCatalystConf + +class UnresolvedOrdinalSubstitutionSuite extends AnalysisTest { + + test("test rule UnresolvedOrdinalSubstitution, replaces ordinal in order by or group by") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true) + + // Expression UnresolvedOrdinal is unresolved. + assert(!UnresolvedOrdinal(0).resolved) + + // Tests order by ordinal, apply single rule. + val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc) + comparePlans( + new UnresolvedOrdinalSubstitution(conf).apply(plan), + testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc)) + + // Tests order by ordinal, do full analysis + checkAnalysis(plan, testRelation2.orderBy(a.asc, b.asc)) + + // order by ordinal can be turned off by config + comparePlans( + new UnresolvedOrdinalSubstitution(conf.copy(orderByOrdinal = false)).apply(plan), + testRelation2.orderBy(Literal(1).asc, Literal(2).asc)) + + + // Tests group by ordinal, apply single rule. 
+ val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b) + comparePlans( + new UnresolvedOrdinalSubstitution(conf).apply(plan2), + testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b)) + + // Tests group by ordinal, do full analysis + checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b)) + + // group by ordinal can be turned off by config + comparePlans( + new UnresolvedOrdinalSubstitution(conf.copy(groupByOrdinal = false)).apply(plan2), + testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)) + } +} diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 36b469c61788c260bf723a77406b1ddc9b705e4f..9c8d851e36e9ba09f658153451459f19857bb9b8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -43,6 +43,12 @@ select a, rand(0), sum(b) from data group by a, 2; -- negative case: star select * from data group by a, b, 1; +-- group by ordinal followed by order by +select a, count(a) from (select 1 as a) tmp group by 1 order by 1; + +-- group by ordinal followed by having +select count(a), a from (select 1 as a) tmp group by 2 having a > 0; + -- turn of group by ordinal set spark.sql.groupByOrdinal=false; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index 2f10b7ebc6d32e13e8b46ffe64741610816dc300..9c3a145f3aaa77ebe7d295d6745a4dec88e17638 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 17 +-- Number of queries: 19 -- !query 0 @@ -153,16 +153,32 @@ Star (*) is not allowed in select list when GROUP BY ordinal position is used; -- !query 15 -set 
spark.sql.groupByOrdinal=false +select a, count(a) from (select 1 as a) tmp group by 1 order by 1 -- !query 15 schema -struct<key:string,value:string> +struct<a:int,count(a):bigint> -- !query 15 output -spark.sql.groupByOrdinal +1 1 -- !query 16 -select sum(b) from data group by -1 +select count(a), a from (select 1 as a) tmp group by 2 having a > 0 -- !query 16 schema -struct<sum(b):bigint> +struct<count(a):bigint,a:int> -- !query 16 output +1 1 + + +-- !query 17 +set spark.sql.groupByOrdinal=false +-- !query 17 schema +struct<key:string,value:string> +-- !query 17 output +spark.sql.groupByOrdinal + + +-- !query 18 +select sum(b) from data group by -1 +-- !query 18 schema +struct<sum(b):bigint> +-- !query 18 output 9