Skip to content
Snippets Groups Projects
Commit d403562e authored by Herman van Hovell's avatar Herman van Hovell
Browse files

[SPARK-17114][SQL] Fix aggregates grouped by literals with empty input

## What changes were proposed in this pull request?
This PR fixes an issue with aggregates that have an empty input, and use a literals as their grouping keys. These aggregates are currently interpreted as aggregates **without** grouping keys, this triggers the ungrouped code path (which aways returns a single row).

This PR fixes the `RemoveLiteralFromGroupExpressions` optimizer rule, which changes the semantics of the Aggregate by eliminating all literal grouping keys.

## How was this patch tested?
Added tests to `SQLQueryTestSuite`.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #15101 from hvanhovell/SPARK-17114-3.
parent 5b8f7377
No related branches found
No related tags found
No related merge requests found
......@@ -1098,9 +1098,16 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
*/
object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a @ Aggregate(grouping, _, _) =>
case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
val newGrouping = grouping.filter(!_.foldable)
a.copy(groupingExpressions = newGrouping)
if (newGrouping.nonEmpty) {
a.copy(groupingExpressions = newGrouping)
} else {
// All grouping expressions are literals. We should not drop them all, because this can
// change the return semantics when the input of the Aggregate is empty (SPARK-17114). We
// instead replace this by single, easy to hash/sort, literal expression.
a.copy(groupingExpressions = Seq(Literal(0, IntegerType)))
}
}
}
......
......@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
class AggregateOptimizeSuite extends PlanTest {
val conf = new SimpleCatalystConf(caseSensitiveAnalysis = false)
val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
val analyzer = new Analyzer(catalog, conf)
......@@ -49,6 +49,14 @@ class AggregateOptimizeSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
test("do not remove all grouping expressions if they are all literals") {
val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
val optimized = Optimize.execute(analyzer.execute(query))
val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))
comparePlans(optimized, correctAnswer)
}
test("Remove aliased literals") {
val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
val optimized = Optimize.execute(analyzer.execute(query))
......
-- Temporary data.
create temporary view myview as values 128, 256 as v(int_col);
-- group by should produce all input rows,
select int_col, count(*) from myview group by int_col;
-- group by should produce a single row.
select 'foo', count(*) from myview group by 1;
-- group-by should not produce any rows (whole stage code generation).
select 'foo' from myview where int_col == 0 group by 1;
-- group-by should not produce any rows (hash aggregate).
select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1;
-- group-by should not produce any rows (sort aggregate).
select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1;
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 6
-- !query 0
create temporary view myview as values 128, 256 as v(int_col)
-- !query 0 schema
struct<>
-- !query 0 output
-- !query 1
select int_col, count(*) from myview group by int_col
-- !query 1 schema
struct<int_col:int,count(1):bigint>
-- !query 1 output
128 1
256 1
-- !query 2
select 'foo', count(*) from myview group by 1
-- !query 2 schema
struct<foo:string,count(1):bigint>
-- !query 2 output
foo 2
-- !query 3
select 'foo' from myview where int_col == 0 group by 1
-- !query 3 schema
struct<foo:string>
-- !query 3 output
-- !query 4
select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1
-- !query 4 schema
struct<foo:string,approx_count_distinct(int_col):bigint>
-- !query 4 output
-- !query 5
select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1
-- !query 5 schema
struct<foo:string,max(struct(int_col)):struct<int_col:int>>
-- !query 5 output
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment