diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 876aa0eae0e904f850a5f45a023cbcae6479a430..36eb59ef5ef9c597ad0840a0917f7c7497d95f6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -181,8 +181,8 @@ class Analyzer(
       case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
         Aggregate(groups, assignAliases(aggs), child)
 
-      case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
-        g.withNewAggs(assignAliases(g.aggregations))
+      case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
+        g.copy(aggregations = assignAliases(g.aggregations))
 
       case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
         if child.resolved && hasUnresolvedAlias(groupByExprs) =>
@@ -250,13 +250,9 @@ class Analyzer(
 
         val nonNullBitmask = x.bitmasks.reduce(_ & _)
 
-        val attributeMap = groupByAliases.zipWithIndex.map { case (a, idx) =>
-          if ((nonNullBitmask & 1 << idx) == 0) {
-            (a -> a.toAttribute.withNullability(true))
-          } else {
-            (a -> a.toAttribute)
-          }
-        }.toMap
+        val groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
+          a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0)
+        }
 
         val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
           // collect all the found AggregateExpression, so we can check an expression is part of
@@ -292,12 +288,16 @@ class Analyzer(
                   s"in grouping columns ${x.groupByExprs.mkString(",")}")
               }
             case e =>
-              groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e)
+              val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
+              if (index == -1) {
+                e
+              } else {
+                groupByAttributes(index)
+              }
           }.asInstanceOf[NamedExpression]
         }
 
         val child = Project(x.child.output ++ groupByAliases, x.child)
-        val groupByAttributes = groupByAliases.map(attributeMap(_))
 
         Aggregate(
           groupByAttributes :+ VirtualColumn.groupingIdAttribute,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index e81a0f9487469251b357affc51a945be39bd2db8..522348735aadf89f771bb6efa87e6c03320e00b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -533,20 +533,6 @@ case class Expand(
   }
 }
 
-trait GroupingAnalytics extends UnaryNode {
-
-  def groupByExprs: Seq[Expression]
-  def aggregations: Seq[NamedExpression]
-
-  override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
-
-  // Needs to be unresolved before its translated to Aggregate + Expand because output attributes
-  // will change in analysis.
-  override lazy val resolved: Boolean = false
-
-  def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
-}
-
 /**
  * A GROUP BY clause with GROUPING SETS can generate a result set equivalent
  * to generated by a UNION ALL of multiple simple GROUP BY clauses.
@@ -565,10 +551,13 @@ case class GroupingSets(
     bitmasks: Seq[Int],
     groupByExprs: Seq[Expression],
     child: LogicalPlan,
-    aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
+    aggregations: Seq[NamedExpression]) extends UnaryNode {
+
+  override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
 
-  def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
-    this.copy(aggregations = aggs)
+  // Needs to be unresolved before its translated to Aggregate + Expand because output attributes
+  // will change in analysis.
+  override lazy val resolved: Boolean = false
 }
 
 case class Pivot(
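Note on the Analyzer hunk above: the new `groupByAttributes` computation folds the old if/else into a single `withNullability` call. A minimal standalone sketch of the rule it implements, with hypothetical names and types rather than Spark's `Alias`/`Attribute` (assuming, as the `reduce(_ & _)` in the patch implies, that bit `i` set in a bitmask means column `i` participates in that grouping set): a grouping column stays non-nullable only if its bit is set in every grouping set's bitmask.

```scala
// Hypothetical stand-in for Spark's Alias/Attribute pair: a name plus a
// nullability flag, just enough to show the bitmask rule.
final case class Attr(name: String, nullable: Boolean)

object BitmaskNullabilitySketch {
  def groupByAttributes(names: Seq[String], bitmasks: Seq[Int]): Seq[Attr] = {
    // AND-reduce: a bit survives only if it is set in every grouping set.
    val nonNullBitmask = bitmasks.reduce(_ & _)
    names.zipWithIndex.map { case (name, idx) =>
      // A bit cleared in some grouping set means Expand may null-fill this
      // column, so it must be marked nullable.
      Attr(name, nullable = (nonNullBitmask & (1 << idx)) == 0)
    }
  }

  def main(args: Array[String]): Unit = {
    // GROUPING SETS ((a, b), (a)): a (bit 0) is in both sets, b (bit 1)
    // only in the first, so only b becomes nullable.
    println(groupByAttributes(Seq("a", "b"), Seq(0x3, 0x1)))
    // List(Attr(a,false), Attr(b,true))
  }
}
```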
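And on the basicOperators.scala hunk: with `GroupingSets` as the trait's only implementation, `GroupingAnalytics` existed mainly to expose `withNewAggs`, which the compiler-generated `copy` of a case class already provides; the `output` and `resolved` overrides move verbatim into the case class. A minimal sketch of that equivalence, using a simplified hypothetical node rather than the real `LogicalPlan` hierarchy:

```scala
// Simplified, hypothetical version of the node: a plain case class gets a
// compiler-generated copy method, so no withNewAggs abstraction is needed.
final case class GroupingSets(bitmasks: Seq[Int], aggregations: Seq[String])

object CopyVsWithNewAggs {
  def main(args: Array[String]): Unit = {
    val g = GroupingSets(Seq(0x3, 0x1), Seq("sum(x)"))
    // Before the patch: g.withNewAggs(newAggs), dispatched through the trait.
    // After the patch: the case-class copy does the same thing directly.
    val updated = g.copy(aggregations = Seq("sum(x) AS total"))
    println(updated) // GroupingSets(List(3, 1),List(sum(x) AS total))
  }
}
```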