Skip to content
Snippets Groups Projects
Commit 37cff1b1 authored by Andrew Ray's avatar Andrew Ray Committed by Yin Huai
Browse files

[SPARK-11275][SQL] Incorrect results when using rollup/cube

Fixes bug with grouping sets (including cube/rollup) where aggregates that included grouping expressions would return the wrong (null) result.

Also simplifies the analyzer rule a bit and leaves column pruning to the optimizer.

Added multiple unit tests to DataFrameAggregateSuite and verified it passes hive compatibility suite:
```
build/sbt -Phive -Dspark.hive.whitelist='groupby.*_grouping.*' 'test-only org.apache.spark.sql.hive.execution.HiveCompatibilitySuite'
```

This is an alternative to PR https://github.com/apache/spark/pull/9419, but I think it's better, as it simplifies the analyzer rule instead of adding another special case to it.

Author: Andrew Ray <ray.andrew@gmail.com>

Closes #9815 from aray/groupingset-agg-fix.
parent 01403aa9
No related branches found
No related tags found
No related merge requests found
......@@ -213,45 +213,35 @@ class Analyzer(
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
case x: GroupingSets =>
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
// We will insert another Projection if the GROUP BY keys contains the
// non-attribute expressions. And the top operators can references those
// expressions by its alias.
// e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==>
// SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a
// find all of the non-attribute expressions in the GROUP BY keys
val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]()
// The pair of (the original GROUP BY key, associated attribute)
val groupByExprPairs = x.groupByExprs.map(_ match {
case e: NamedExpression => (e, e.toAttribute)
case other => {
val alias = Alias(other, other.toString)()
nonAttributeGroupByExpressions += alias // add the non-attributes expression alias
(other, alias.toAttribute)
}
})
// substitute the non-attribute expressions for aggregations.
val aggregation = x.aggregations.map(expr => expr.transformDown {
case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e)
}.asInstanceOf[NamedExpression])
// substitute the group by expressions.
val newGroupByExprs = groupByExprPairs.map(_._2)
// Expand works by setting grouping expressions to null as determined by the bitmasks. To
// prevent these null values from being used in an aggregate instead of the original value
// we need to create new aliases for all group by expressions that will only be used for
// the intended purpose.
val groupByAliases: Seq[Alias] = x.groupByExprs.map {
case e: NamedExpression => Alias(e, e.name)()
case other => Alias(other, other.toString)()
}
val child = if (nonAttributeGroupByExpressions.length > 0) {
// insert additional projection if contains the
// non-attribute expressions in the GROUP BY keys
Project(x.child.output ++ nonAttributeGroupByExpressions, x.child)
} else {
x.child
val aggregations: Seq[NamedExpression] = x.aggregations.map {
// If an expression is an aggregate (contains a AggregateExpression) then we dont change
// it so that the aggregation is computed on the unmodified value of its argument
// expressions.
case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr
// If not then its a grouping expression and we need to use the modified (with nulls from
// Expand) value of the expression.
case expr => expr.transformDown {
case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e)
}.asInstanceOf[NamedExpression]
}
val child = Project(x.child.output ++ groupByAliases, x.child)
val groupByAttributes = groupByAliases.map(_.toAttribute)
Aggregate(
newGroupByExprs :+ VirtualColumn.groupingIdAttribute,
aggregation,
Expand(x.bitmasks, newGroupByExprs, gid, child))
groupByAttributes :+ VirtualColumn.groupingIdAttribute,
aggregations,
Expand(x.bitmasks, groupByAttributes, gid, child))
}
}
......
......@@ -323,6 +323,10 @@ trait GroupingAnalytics extends UnaryNode {
// The output schema mirrors the aggregation list: one attribute per named aggregate expression.
override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
// Needs to be unresolved before it's translated to Aggregate + Expand because output attributes
// will change in analysis.
override lazy val resolved: Boolean = false
// Returns a copy of this node with its aggregation expressions replaced by `aggs`.
def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
}
......
......@@ -60,6 +60,68 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
// ROLLUP over (course, year) produces per-(course, year) totals, per-course subtotals
// (year = null), and a grand total (both null).
test("rollup") {
  checkAnswer(
    courseSales.rollup("course", "year").sum("earnings"),
    Seq(
      Row("Java", 2012, 20000.0),
      Row("Java", 2013, 30000.0),
      Row("Java", null, 50000.0),
      Row("dotNET", 2012, 15000.0),
      Row("dotNET", 2013, 48000.0),
      Row("dotNET", null, 63000.0),
      Row(null, null, 113000.0)))
}
// CUBE over (course, year) produces every grouping combination: per-(course, year),
// per-course (year = null), per-year (course = null), and the grand total (both null).
test("cube") {
  checkAnswer(
    courseSales.cube("course", "year").sum("earnings"),
    Seq(
      Row("Java", 2012, 20000.0),
      Row("Java", 2013, 30000.0),
      Row("Java", null, 50000.0),
      Row("dotNET", 2012, 15000.0),
      Row("dotNET", 2013, 48000.0),
      Row("dotNET", null, 63000.0),
      Row(null, 2012, 35000.0),
      Row(null, 2013, 78000.0),
      Row(null, null, 113000.0)))
}
// Regression test for SPARK-11275: aggregates whose arguments reference grouping expressions
// (here sum($"a" - $"b") overlaps grouping keys $"a" + $"b" and $"b") must be computed on the
// original column values rather than the null-filled copies produced for the grouping sets —
// the pre-fix behavior returned wrong (null) results for such aggregates.
test("rollup overlapping columns") {
checkAnswer(
testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),
Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
:: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
:: Row(null, null, 3) :: Nil
)
// Same scenario using string-named columns: sum("b") aggregates a column that is also a
// grouping key, so subtotal rows must still sum the real "b" values.
checkAnswer(
testData2.rollup("a", "b").agg(sum("b")),
Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
:: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
:: Row(null, null, 9) :: Nil
)
}
// Regression test for SPARK-11275, CUBE variant: the aggregate sum($"a" - $"b") references the
// grouping expressions $"a" + $"b" and $"b", and must see the original values even in the
// subtotal rows where the grouping columns themselves are nulled out.
test("cube overlapping columns") {
checkAnswer(
testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
:: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
:: Row(null, 1, 3) :: Row(null, 2, 0)
:: Row(null, null, 3) :: Nil
)
// Same scenario using string-named columns: "b" is both a grouping key and the aggregated
// column, so per-"b" and grand-total rows must sum the unmodified "b" values.
checkAnswer(
testData2.cube("a", "b").agg(sum("b")),
Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
:: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
:: Row(null, 1, 3) :: Row(null, 2, 6)
:: Row(null, null, 9) :: Nil
)
}
test("spark.sql.retainGroupColumns config") {
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment