Skip to content
Snippets Groups Projects
Commit 66e7a8f1 authored by Wenchen Fan's avatar Wenchen Fan Committed by Herman van Hovell
Browse files

[SPARK-20409][SQL] fail early if aggregate function in GROUP BY

## What changes were proposed in this pull request?

It's illegal to have aggregate function in GROUP BY, and we should fail at analysis phase, if this happens.

## How was this patch tested?

new regression test

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17704 from cloud-fan/minor.
parent 9e5dc82a
No related branches found
No related tags found
No related merge requests found
...@@ -727,7 +727,7 @@ class Analyzer( ...@@ -727,7 +727,7 @@ class Analyzer(
case p if !p.childrenResolved => p case p if !p.childrenResolved => p
// Replace the index with the related attribute for ORDER BY, // Replace the index with the related attribute for ORDER BY,
// which is a 1-base position of the projection list. // which is a 1-base position of the projection list.
case s @ Sort(orders, global, child) case Sort(orders, global, child)
if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
val newOrders = orders map { val newOrders = orders map {
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) => case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) =>
...@@ -744,17 +744,11 @@ class Analyzer( ...@@ -744,17 +744,11 @@ class Analyzer(
// Replace the index with the corresponding expression in aggregateExpressions. The index is // Replace the index with the corresponding expression in aggregateExpressions. The index is
// a 1-base position of aggregateExpressions, which is output columns (select expression) // a 1-base position of aggregateExpressions, which is output columns (select expression)
case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
groups.exists(_.isInstanceOf[UnresolvedOrdinal]) => groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
val newGroups = groups.map { val newGroups = groups.map {
case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
aggs(index - 1) match { aggs(index - 1)
case e if ResolveAggregateFunctions.containsAggregate(e) =>
ordinal.failAnalysis(
s"GROUP BY position $index is an aggregate function, and " +
"aggregate functions are not allowed in GROUP BY")
case o => o
}
case ordinal @ UnresolvedOrdinal(index) => case ordinal @ UnresolvedOrdinal(index) =>
ordinal.failAnalysis( ordinal.failAnalysis(
s"GROUP BY position $index is not in select list " + s"GROUP BY position $index is not in select list " +
......
...@@ -266,6 +266,11 @@ trait CheckAnalysis extends PredicateHelper { ...@@ -266,6 +266,11 @@ trait CheckAnalysis extends PredicateHelper {
} }
def checkValidGroupingExprs(expr: Expression): Unit = { def checkValidGroupingExprs(expr: Expression): Unit = {
if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) {
failAnalysis(
"aggregate functions are not allowed in GROUP BY, but found " + expr.sql)
}
// Check if the data type of expr is orderable. // Check if the data type of expr is orderable.
if (!RowOrdering.isOrderable(expr.dataType)) { if (!RowOrdering.isOrderable(expr.dataType)) {
failAnalysis( failAnalysis(
...@@ -283,8 +288,8 @@ trait CheckAnalysis extends PredicateHelper { ...@@ -283,8 +288,8 @@ trait CheckAnalysis extends PredicateHelper {
} }
} }
aggregateExprs.foreach(checkValidAggregateExpression)
groupingExprs.foreach(checkValidGroupingExprs) groupingExprs.foreach(checkValidGroupingExprs)
aggregateExprs.foreach(checkValidAggregateExpression)
case Sort(orders, _, _) => case Sort(orders, _, _) =>
orders.foreach { order => orders.foreach { order =>
......
...@@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3 ...@@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3
struct<> struct<>
-- !query 11 output -- !query 11 output
org.apache.spark.sql.AnalysisException org.apache.spark.sql.AnalysisException
GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39 aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT));
-- !query 12 -- !query 12
...@@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3 ...@@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3
struct<> struct<>
-- !query 12 output -- !query 12 output
org.apache.spark.sql.AnalysisException org.apache.spark.sql.AnalysisException
GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43 aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT));
-- !query 13 -- !query 13
......
...@@ -538,4 +538,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ...@@ -538,4 +538,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0)) Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0))
) )
} }
test("aggregate function in GROUP BY") {
val e = intercept[AnalysisException] {
testData.groupBy(sum($"key")).count()
}
assert(e.message.contains("aggregate functions are not allowed in GROUP BY"))
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment