Skip to content
Snippets Groups Projects
Commit 05f652d6 authored by gatorsmile's avatar gatorsmile Committed by Wenchen Fan
Browse files

[SPARK-13957][SQL] Support Group By Ordinal in SQL

#### What changes were proposed in this pull request?
This PR is to support group by position in SQL. For example, when users input the following query
```SQL
select c1 as a, c2, c3, sum(*) from tbl group by 1, 3, c4
```
The ordinals are recognized as the positions in the select list. Thus, `Analyzer` converts it to
```SQL
select c1, c2, c3, sum(*) from tbl group by c1, c3, c4
```

This is controlled by the config option `spark.sql.groupByOrdinal`.
- When true, the ordinal numbers in group by clauses are treated as the position in the select list.
- When false, the ordinal numbers are ignored.
- Only integer literals are converted (not other foldable expressions). Foldable expressions that are not integer literals are ignored.
- When the positions specified in the group by clauses correspond to the aggregate functions in select list, output an exception message.
- A star (`*`) is not allowed in the select list when users specify ordinals in group by.

Note: This PR is taken from https://github.com/apache/spark/pull/10731. When merging this PR, please give the credit to zhichao-li

Also cc all the people who are involved in the previous discussion:  rxin cloud-fan marmbrus yhuai hvanhovell adrian-wang chenghao-intel tejasapatil

#### How was this patch tested?

Added a few test cases covering both positive and negative scenarios.

Author: gatorsmile <gatorsmile@gmail.com>
Author: xiaoli <lixiao1983@gmail.com>
Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local>

Closes #11846 from gatorsmile/groupByOrdinal.
parent 0874ff3a
No related branches found
No related tags found
No related merge requests found
......@@ -23,6 +23,7 @@ private[spark] trait CatalystConf {
def caseSensitiveAnalysis: Boolean
/** Whether ordinal numbers in order/sort by clauses are treated as positions in the select list. */
def orderByOrdinal: Boolean
/** Whether ordinal numbers in group by clauses are treated as positions in the select list. */
def groupByOrdinal: Boolean
/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
......@@ -48,11 +49,16 @@ object EmptyConf extends CatalystConf {
/** Always throws: this configuration does not provide a value for `orderByOrdinal`. */
override def orderByOrdinal: Boolean = throw new UnsupportedOperationException
/** Always throws: this configuration does not provide a value for `groupByOrdinal`. */
override def groupByOrdinal: Boolean = throw new UnsupportedOperationException
}
/** A CatalystConf that can be used for local testing. */
case class SimpleCatalystConf(
caseSensitiveAnalysis: Boolean,
orderByOrdinal: Boolean = true)
orderByOrdinal: Boolean = true,
groupByOrdinal: Boolean = true)
extends CatalystConf {
}
......@@ -85,6 +85,7 @@ class Analyzer(
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveUpCast ::
ResolveOrdinalInOrderByAndGroupBy ::
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
......@@ -385,7 +386,13 @@ class Analyzer(
p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
failAnalysis(
"Group by position: star is not allowed to use in the select list " +
"when using ordinals in group by")
} else {
a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
}
// If the script transformation input contains Stars, expand it.
case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
......@@ -634,21 +641,23 @@ class Analyzer(
}
}
/**
* In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
* clause. This rule detects such queries and adds the required attributes to the original
* projection, so that they will be available during sorting. Another projection is added to
* remove these attributes after sorting.
*
* This rule also resolves the position number in sort references. This support is introduced
* in Spark 2.0. Before Spark 2.0, the integers in Order By have no effect on output sorting.
* - When the sort references are not integer but foldable expressions, ignore them.
* - When spark.sql.orderByOrdinal is set to false, ignore the position numbers too.
*/
object ResolveSortReferences extends Rule[LogicalPlan] {
/**
* In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by
* clauses. This rule is to convert ordinal positions to the corresponding expressions in the
* select list. This support is introduced in Spark 2.0.
*
* - When the sort references or group by expressions are not integer but foldable expressions,
* just ignore them.
* - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the position
* numbers too.
*
* Before the release of Spark 2.0, the literals in order/sort by and group by clauses
* have no effect on the results.
*/
object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case s: Sort if !s.child.resolved => s
// Replace the index with the related attribute for ORDER BY
case p if !p.childrenResolved => p
// Replace the index with the related attribute for ORDER BY,
// which is a 1-based position of the projection list.
case s @ Sort(orders, global, child)
if conf.orderByOrdinal && orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) =>
......@@ -665,10 +674,41 @@ class Analyzer(
}
Sort(newOrders, global, child)
// Replace the index with the corresponding expression in aggregateExpressions. The index is
// a 1-based position of aggregateExpressions, which is output columns (select expression).
// This case only fires when conf.groupByOrdinal is enabled, every select-list expression is
// already resolved, and at least one grouping expression is matched by IntegerIndex.
case a @ Aggregate(groups, aggs, child)
if conf.groupByOrdinal && aggs.forall(_.resolved) &&
groups.exists(IntegerIndex.unapply(_).nonEmpty) =>
val newGroups = groups.map {
// Ordinal within [1, aggs.size]: substitute the select-list expression at that position.
case IntegerIndex(index) if index > 0 && index <= aggs.size =>
aggs(index - 1) match {
// The referenced select-list expression contains an aggregate function: reject it.
case e if ResolveAggregateFunctions.containsAggregate(e) =>
throw new UnresolvedException(a,
s"Group by position: the '$index'th column in the select contains an " +
s"aggregate function: ${e.sql}. Aggregate functions are not allowed in GROUP BY")
case o => o
}
// Any remaining ordinal is out of range (zero, negative, or larger than the select list).
// NOTE(review): the message says "exceeds the size" even for zero or negative ordinals.
case IntegerIndex(index) =>
throw new UnresolvedException(a,
s"Group by position: '$index' exceeds the size of the select list '${aggs.size}'.")
// Grouping expressions that are not ordinals are kept unchanged.
case o => o
}
Aggregate(newGroups, aggs, child)
}
}
/**
* In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
* clause. This rule detects such queries and adds the required attributes to the original
* projection, so that they will be available during sorting. Another projection is added to
* remove these attributes after sorting.
*/
object ResolveSortReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa
case s @ Sort(order, _, child) if !s.resolved =>
case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
try {
val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
......
......@@ -210,7 +210,8 @@ object Unions {
object IntegerIndex {
/**
 * Extracts the integer value of an ordinal expression: either a plain integer literal or a
 * negated integer literal. Anything else yields None.
 */
def unapply(a: Any): Option[Int] = a match {
  case Literal(ordinal: Int, IntegerType) => Some(ordinal)
  // Negative values are still extracted so that the rules resolving ordinals in Sort and
  // Group By can issue error messages for them.
  case UnaryMinus(IntegerLiteral(magnitude)) => Some(-magnitude)
  case _ => None
}
......
......@@ -445,6 +445,11 @@ object SQLConf {
doc = "When true, the ordinal numbers are treated as the position in the select list. " +
"When false, the ordinal numbers in order/sort By clause are ignored.")
// Boolean option "spark.sql.groupByOrdinal": toggles whether ordinal numbers in group by
// clauses are resolved against the select list. Defaults to true.
val GROUP_BY_ORDINAL = booleanConf("spark.sql.groupByOrdinal",
defaultValue = Some(true),
doc = "When true, the ordinal numbers in group by clauses are treated as the position " +
"in the select list. When false, the ordinal numbers are ignored.")
// The output committer class used by HadoopFsRelation. The specified class needs to be a
// subclass of org.apache.hadoop.mapreduce.OutputCommitter.
//
......@@ -668,6 +673,7 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
/** Reads the current value of the ORDER_BY_ORDINAL option. */
override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
/** Reads the current value of the GROUP_BY_ORDINAL option. */
override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
......
......@@ -23,6 +23,7 @@ import java.sql.Timestamp
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin}
import org.apache.spark.sql.functions._
......@@ -459,25 +460,103 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Seq(Row(1, 3), Row(2, 3), Row(3, 3)))
}
test("literal in agg grouping expressions") {
test("Group By Ordinal - basic") {
checkAnswer(
sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
checkAnswer(
sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
sql("SELECT a, sum(b) FROM testData2 GROUP BY 1"),
sql("SELECT a, sum(b) FROM testData2 GROUP BY a"))
// duplicate group-by columns
checkAnswer(
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
checkAnswer(
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY 1, 2"),
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
}
test("Group By Ordinal - non aggregate expressions") {
// Ordinal 2 refers to the second select-list item, `b + 2`; the result must match the query
// that groups by that expression explicitly.
checkAnswer(
sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, 2"),
sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
// Same check when the second select-list item carries an alias.
checkAnswer(
sql("SELECT a, b + 2 as c, count(2) FROM testData2 GROUP BY a, 2"),
sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
}
// NOTE(review): `1 + 0` and `1 + 2` look like foldable constant expressions rather than
// non-foldable ones, despite the test name -- confirm the intended wording.
test("Group By Ordinal - non-foldable constant expression") {
// `1 + 0` is not an integer literal, so it is expected not to act as an ordinal: the result
// must match the same query without it in the group by list.
checkAnswer(
sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b, 1 + 0"),
sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b"))
// Likewise `1 + 2` is expected to have no effect on the grouping.
checkAnswer(
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
}
test("Group By Ordinal - alias") {
// Ordinal 2 points at the aliased item `(b + 2) as c`; the result must match grouping by
// `b + 2` directly.
checkAnswer(
sql("SELECT a, (b + 2) as c, count(2) FROM testData2 GROUP BY a, 2"),
sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
// The aliases swap the column names; grouping by positions 1 and 2 must still match
// grouping by the underlying columns a and b.
checkAnswer(
sql("SELECT a as b, b as a, sum(b) FROM testData2 GROUP BY 1, 2"),
sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b"))
}
test("Group By Ordinal - constants") {
// Ordinals 1 and 2 point at the constant select-list items 1 and 2; the result is expected
// to equal the same aggregate query without any group by clause.
checkAnswer(
sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
sql("SELECT 1, 2, sum(b) FROM testData2"))
}
test("Group By Ordinal - negative cases") {
  // Ordinals that are negative or larger than the select list must be rejected.
  Seq(
    "SELECT a, b FROM testData2 GROUP BY -1",
    "SELECT a, b FROM testData2 GROUP BY 3"
  ).foreach { query =>
    intercept[UnresolvedException[Aggregate]] {
      sql(query)
    }
  }

  // An ordinal must not point at a select-list item that contains an aggregate function.
  val aggregateInGroupBy =
    "Invalid call to Group by position: the '1'th column in the select contains " +
      "an aggregate function"
  Seq(
    "SELECT SUM(a) FROM testData2 GROUP BY 1",
    "SELECT SUM(a) + 1 FROM testData2 GROUP BY 1"
  ).foreach { query =>
    val unresolved = intercept[UnresolvedException[Aggregate]] {
      sql(query)
    }
    assert(unresolved.getMessage.contains(aggregateInGroupBy))
  }

  // An ordinal that resolves to a nondeterministic expression is reported as an analysis error.
  val nondeterministic = intercept[AnalysisException] {
    sql("SELECT a, rand(0), sum(b) FROM testData2 GROUP BY a, 2")
  }
  assert(nondeterministic.getMessage.contains(
    "nondeterministic expression rand(0) should not appear in grouping expression"))

  // A star in the select list cannot be combined with ordinals in group by.
  val starWithOrdinal = intercept[AnalysisException] {
    sql("SELECT * FROM testData2 GROUP BY a, b, 1")
  }
  assert(starWithOrdinal.getMessage.contains(
    "Group by position: star is not allowed to use in the select list " +
      "when using ordinals in group by"))
}
test("Group By Ordinal: spark.sql.groupByOrdinal=false") {
withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") {
// If spark.sql.groupByOrdinal=false, ignore the position number.
// The query is then expected to fail analysis -- presumably because `a` is no longer part of
// the grouping expressions; confirm against the analyzer.
intercept[AnalysisException] {
sql("SELECT a, sum(b) FROM testData2 GROUP BY 1")
}
// With ordinals ignored, '*' in the select list is accepted even though a number appears in
// group by: the result must match the same query without the trailing `1`.
checkAnswer(
sql("SELECT * FROM testData2 GROUP BY a, b, 1"),
sql("SELECT * FROM testData2 GROUP BY a, b"))
}
}
test("aggregates with nulls") {
checkAnswer(
sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
......@@ -2174,7 +2253,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY 1 + 0 DESC, b ASC"),
sql("SELECT * FROM testData2 ORDER BY b ASC"))
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"),
sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment