From 37cff1b1a79cad11277612cb9bc8bc2365cf5ff2 Mon Sep 17 00:00:00 2001
From: Andrew Ray <ray.andrew@gmail.com>
Date: Thu, 19 Nov 2015 15:11:30 -0800
Subject: [PATCH] [SPARK-11275][SQL] Incorrect results when using rollup/cube

Fixes bug with grouping sets (including cube/rollup) where aggregates that included grouping expressions would return the wrong (null) result.

Also simplifies the analyzer rule a bit and leaves column pruning to the optimizer.

Added multiple unit tests to DataFrameAggregateSuite and verified it passes hive compatibility suite:
```
build/sbt -Phive -Dspark.hive.whitelist='groupby.*_grouping.*' 'test-only org.apache.spark.sql.hive.execution.HiveCompatibilitySuite'
```

This is an alternative to pr https://github.com/apache/spark/pull/9419 but I think its better as it simplifies the analyzer rule instead of adding another special case to it.

Author: Andrew Ray <ray.andrew@gmail.com>

Closes #9815 from aray/groupingset-agg-fix.
---
 .../sql/catalyst/analysis/Analyzer.scala      | 58 +++++++----------
 .../plans/logical/basicOperators.scala        |  4 ++
 .../spark/sql/DataFrameAggregateSuite.scala   | 62 +++++++++++++++++++
 3 files changed, 90 insertions(+), 34 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 84781cd57f..47962ebe6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -213,45 +213,35 @@ class Analyzer(
         GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
       case x: GroupingSets =>
         val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
-        // We will insert another Projection if the GROUP BY keys contains the
-        // non-attribute expressions. And the top operators can references those
-        // expressions by its alias.
-        // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==>
-        //      SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a
-
-        // find all of the non-attribute expressions in the GROUP BY keys
-        val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]()
-
-        // The pair of (the original GROUP BY key, associated attribute)
-        val groupByExprPairs = x.groupByExprs.map(_ match {
-          case e: NamedExpression => (e, e.toAttribute)
-          case other => {
-            val alias = Alias(other, other.toString)()
-            nonAttributeGroupByExpressions += alias // add the non-attributes expression alias
-            (other, alias.toAttribute)
-          }
-        })
-
-        // substitute the non-attribute expressions for aggregations.
-        val aggregation = x.aggregations.map(expr => expr.transformDown {
-          case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e)
-        }.asInstanceOf[NamedExpression])
 
-        // substitute the group by expressions.
-        val newGroupByExprs = groupByExprPairs.map(_._2)
+        // Expand works by setting grouping expressions to null as determined by the bitmasks. To
+        // prevent these null values from being used in an aggregate instead of the original value
+        // we need to create new aliases for all group by expressions that will only be used for
+        // the intended purpose.
+        val groupByAliases: Seq[Alias] = x.groupByExprs.map {
+          case e: NamedExpression => Alias(e, e.name)()
+          case other => Alias(other, other.toString)()
+        }
 
-        val child = if (nonAttributeGroupByExpressions.length > 0) {
-          // insert additional projection if contains the
-          // non-attribute expressions in the GROUP BY keys
-          Project(x.child.output ++ nonAttributeGroupByExpressions, x.child)
-        } else {
-          x.child
+        val aggregations: Seq[NamedExpression] = x.aggregations.map {
+          // If an expression is an aggregate (contains a AggregateExpression) then we dont change
+          // it so that the aggregation is computed on the unmodified value of its argument
+          // expressions.
+          case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr
+          // If not then its a grouping expression and we need to use the modified (with nulls from
+          // Expand) value of the expression.
+          case expr => expr.transformDown {
+            case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e)
+          }.asInstanceOf[NamedExpression]
         }
 
+        val child = Project(x.child.output ++ groupByAliases, x.child)
+        val groupByAttributes = groupByAliases.map(_.toAttribute)
+
         Aggregate(
-          newGroupByExprs :+ VirtualColumn.groupingIdAttribute,
-          aggregation,
-          Expand(x.bitmasks, newGroupByExprs, gid, child))
+          groupByAttributes :+ VirtualColumn.groupingIdAttribute,
+          aggregations,
+          Expand(x.bitmasks, groupByAttributes, gid, child))
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 45630a591d..0c444482c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -323,6 +323,10 @@ trait GroupingAnalytics extends UnaryNode {
 
   override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
 
+  // Needs to be unresolved before its translated to Aggregate + Expand because output attributes
+  // will change in analysis.
+  override lazy val resolved: Boolean = false
+
   def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 71adf2148a..9c42f65bb6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -60,6 +60,68 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("rollup") {
+    checkAnswer(
+      courseSales.rollup("course", "year").sum("earnings"),
+      Row("Java", 2012, 20000.0) ::
+        Row("Java", 2013, 30000.0) ::
+        Row("Java", null, 50000.0) ::
+        Row("dotNET", 2012, 15000.0) ::
+        Row("dotNET", 2013, 48000.0) ::
+        Row("dotNET", null, 63000.0) ::
+        Row(null, null, 113000.0) :: Nil
+    )
+  }
+
+  test("cube") {
+    checkAnswer(
+      courseSales.cube("course", "year").sum("earnings"),
+      Row("Java", 2012, 20000.0) ::
+        Row("Java", 2013, 30000.0) ::
+        Row("Java", null, 50000.0) ::
+        Row("dotNET", 2012, 15000.0) ::
+        Row("dotNET", 2013, 48000.0) ::
+        Row("dotNET", null, 63000.0) ::
+        Row(null, 2012, 35000.0) ::
+        Row(null, 2013, 78000.0) ::
+        Row(null, null, 113000.0) :: Nil
+    )
+  }
+
+  test("rollup overlapping columns") {
+    checkAnswer(
+      testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),
+      Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
+        :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
+        :: Row(null, null, 3) :: Nil
+    )
+
+    checkAnswer(
+      testData2.rollup("a", "b").agg(sum("b")),
+      Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
+        :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
+        :: Row(null, null, 9) :: Nil
+    )
+  }
+
+  test("cube overlapping columns") {
+    checkAnswer(
+      testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+      Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
+        :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
+        :: Row(null, 1, 3) :: Row(null, 2, 0)
+        :: Row(null, null, 3) :: Nil
+    )
+
+    checkAnswer(
+      testData2.cube("a", "b").agg(sum("b")),
+      Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
+        :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
+        :: Row(null, 1, 3) :: Row(null, 2, 6)
+        :: Row(null, null, 9) :: Nil
+    )
+  }
+
   test("spark.sql.retainGroupColumns config") {
     checkAnswer(
       testData2.groupBy("a").agg(sum($"b")),
-- 
GitLab