From b60b8137992641b9193e57061aa405f908b0f267 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Wed, 2 Mar 2016 20:18:57 -0800
Subject: [PATCH] [SPARK-13617][SQL] remove unnecessary GroupingAnalytics trait

## What changes were proposed in this pull request?

The `GroupingAnalytics` trait has only one implementation, so it is an unnecessary abstraction. This PR removes it and simplifies the code that resolves `GroupingSets`.
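
On the "simplifies the code" part: the rewritten `ResolveGroupingAnalytics` logic computes each grouping attribute's nullability directly from the bitwise AND of the grouping-set bitmasks, instead of going through an alias-to-attribute map. The sketch below illustrates just that nullability rule in isolation, assuming the bitmask convention implied by the patched code (a set bit at position `idx` means grouping column `idx` is kept in that grouping set); `Column` and `resolveNullability` are illustrative names, not Spark APIs.

```scala
// Minimal, self-contained sketch of the nullability rule used when resolving GROUPING SETS.
object GroupingSetsNullabilitySketch {

  // Illustrative stand-in for a grouping column; not a Spark class.
  final case class Column(name: String, nullable: Boolean)

  // A grouping column stays non-nullable only if its bit is set in every grouping-set
  // bitmask, i.e. in the bitwise AND of all masks; otherwise some grouping set nulls it out.
  def resolveNullability(columns: Seq[Column], bitmasks: Seq[Int]): Seq[Column] = {
    val nonNullBitmask = bitmasks.reduce(_ & _)
    columns.zipWithIndex.map { case (col, idx) =>
      col.copy(nullable = (nonNullBitmask & (1 << idx)) == 0)
    }
  }

  def main(args: Array[String]): Unit = {
    val cols = Seq(Column("a", nullable = false), Column("b", nullable = false))
    // Two grouping sets: {a, b} -> 0b11 and {a} -> 0b01; only `a` is kept in both.
    println(resolveNullability(cols, Seq(3, 1)))
    // List(Column(a,false), Column(b,true))
  }
}
```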

## How was this patch tested?

existing tests

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11469 from cloud-fan/groupingset.
---
 .../sql/catalyst/analysis/Analyzer.scala      | 22 +++++++++---------
 .../plans/logical/basicOperators.scala        | 23 +++++--------------
 2 files changed, 17 insertions(+), 28 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 876aa0eae0..36eb59ef5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -181,8 +181,8 @@ class Analyzer(
       case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
         Aggregate(groups, assignAliases(aggs), child)
 
-      case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
-        g.withNewAggs(assignAliases(g.aggregations))
+      case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
+        g.copy(aggregations = assignAliases(g.aggregations))
 
       case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
         if child.resolved && hasUnresolvedAlias(groupByExprs) =>
@@ -250,13 +250,9 @@ class Analyzer(
 
         val nonNullBitmask = x.bitmasks.reduce(_ & _)
 
-        val attributeMap = groupByAliases.zipWithIndex.map { case (a, idx) =>
-          if ((nonNullBitmask & 1 << idx) == 0) {
-            (a -> a.toAttribute.withNullability(true))
-          } else {
-            (a -> a.toAttribute)
-          }
-        }.toMap
+        val groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
+          a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0)
+        }
 
         val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
           // collect all the found AggregateExpression, so we can check an expression is part of
@@ -292,12 +288,16 @@ class Analyzer(
                   s"in grouping columns ${x.groupByExprs.mkString(",")}")
               }
             case e =>
-              groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e)
+              val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
+              if (index == -1) {
+                e
+              } else {
+                groupByAttributes(index)
+              }
           }.asInstanceOf[NamedExpression]
         }
 
         val child = Project(x.child.output ++ groupByAliases, x.child)
-        val groupByAttributes = groupByAliases.map(attributeMap(_))
 
         Aggregate(
           groupByAttributes :+ VirtualColumn.groupingIdAttribute,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index e81a0f9487..522348735a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -533,20 +533,6 @@ case class Expand(
   }
 }
 
-trait GroupingAnalytics extends UnaryNode {
-
-  def groupByExprs: Seq[Expression]
-  def aggregations: Seq[NamedExpression]
-
-  override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
-
-  // Needs to be unresolved before its translated to Aggregate + Expand because output attributes
-  // will change in analysis.
-  override lazy val resolved: Boolean = false
-
-  def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
-}
-
 /**
  * A GROUP BY clause with GROUPING SETS can generate a result set equivalent
  * to generated by a UNION ALL of multiple simple GROUP BY clauses.
@@ -565,10 +551,13 @@ case class GroupingSets(
     bitmasks: Seq[Int],
     groupByExprs: Seq[Expression],
     child: LogicalPlan,
-    aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
+    aggregations: Seq[NamedExpression]) extends UnaryNode {
+
+  override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
 
-  def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
-    this.copy(aggregations = aggs)
+  // Needs to be unresolved before it's translated to Aggregate + Expand because output attributes
+  // will change in analysis.
+  override lazy val resolved: Boolean = false
 }
 
 case class Pivot(
-- 
GitLab