diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8dbec408002f11fd72a1ad458a334696f698f8c9..dd68d60d3e8396db1efccfe05171e2b15c989e37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -217,11 +217,9 @@ class Analyzer( * Group Count: N + 1 (N is the number of group expressions) * * We need to get all of its subsets for the rule described above, the subset is - * represented as the bit masks. + * represented as sequence of expressions. */ - def bitmasks(r: Rollup): Seq[Int] = { - Seq.tabulate(r.groupByExprs.length + 1)(idx => (1 << idx) - 1) - } + def rollupExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.inits.toSeq /* * GROUP BY a, b, c WITH CUBE @@ -230,10 +228,14 @@ class Analyzer( * Group Count: 2 ^ N (N is the number of group expressions) * * We need to get all of its subsets for a given GROUPBY expression, the subsets are - * represented as the bit masks. + * represented as sequence of expressions. 
*/ - def bitmasks(c: Cube): Seq[Int] = { - Seq.tabulate(1 << c.groupByExprs.length)(i => i) + def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.toList match { + case x :: xs => + val initial = cubeExprs(xs) + initial.map(x +: _) ++ initial + case Nil => + Seq(Seq.empty) } private def hasGroupingAttribute(expr: Expression): Boolean = { @@ -256,17 +258,17 @@ class Analyzer( expr transform { case e: GroupingID => if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) { - gid + Alias(gid, toPrettySQL(e))() } else { throw new AnalysisException( s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + s"grouping columns (${groupByExprs.mkString(",")})") } - case Grouping(col: Expression) => + case e @ Grouping(col: Expression) => val idx = groupByExprs.indexOf(col) if (idx >= 0) { - Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), - Literal(1)), ByteType) + Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), + Literal(1)), ByteType), toPrettySQL(e))() } else { throw new AnalysisException(s"Column of grouping ($col) can't be found " + s"in grouping columns ${groupByExprs.mkString(",")}") @@ -274,85 +276,107 @@ class Analyzer( } } - // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case a if !a.childrenResolved => a // be sure all of the children are resolved. 
- case p if p.expressions.exists(hasGroupingAttribute) => - failAnalysis( - s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead") - - case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => - GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions) - case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => - GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions) + /* + * Create new alias for all group by expressions for `Expand` operator. + */ + private def constructGroupByAlias(groupByExprs: Seq[Expression]): Seq[Alias] = { + groupByExprs.map { + case e: NamedExpression => Alias(e, e.name)() + case other => Alias(other, other.toString)() + } + } - // Ensure all the expressions have been resolved. - case x: GroupingSets if x.expressions.forall(_.resolved) => - val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() - - // Expand works by setting grouping expressions to null as determined by the bitmasks. To - // prevent these null values from being used in an aggregate instead of the original value - // we need to create new aliases for all group by expressions that will only be used for - // the intended purpose. - val groupByAliases: Seq[Alias] = x.groupByExprs.map { - case e: NamedExpression => Alias(e, e.name)() - case other => Alias(other, other.toString)() + /* + * Construct [[Expand]] operator with grouping sets. + */ + private def constructExpand( + selectedGroupByExprs: Seq[Seq[Expression]], + child: LogicalPlan, + groupByAliases: Seq[Alias], + gid: Attribute): LogicalPlan = { + // Change the nullability of group by aliases if necessary. For example, if we have + // GROUPING SETS ((a,b), a), we do not need to change the nullability of a, but we + // should change the nullability of b to be TRUE. + // TODO: For Cube/Rollup just set nullability to be `true`. 
+ val expandedAttributes = groupByAliases.map { alias => + if (selectedGroupByExprs.exists(!_.contains(alias.child))) { + alias.toAttribute.withNullability(true) + } else { + alias.toAttribute } + } - // The rightmost bit in the bitmasks corresponds to the last expression in groupByAliases - // with 0 indicating this expression is in the grouping set. The following line of code - // calculates the bitmask representing the expressions that absent in at least one grouping - // set (indicated by 1). - val nullBitmask = x.bitmasks.reduce(_ | _) - - val attrLength = groupByAliases.length - val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) => - a.toAttribute.withNullability(((nullBitmask >> (attrLength - idx - 1)) & 1) == 1) + val groupingSetsAttributes = selectedGroupByExprs.map { groupingSetExprs => + groupingSetExprs.map { expr => + val alias = groupByAliases.find(_.child.semanticEquals(expr)).getOrElse( + failAnalysis(s"$expr doesn't show up in the GROUP BY list $groupByAliases")) + // Map alias to expanded attribute. + expandedAttributes.find(_.semanticEquals(alias.toAttribute)).getOrElse( + alias.toAttribute) } + } - val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child) - val groupingAttrs = expand.output.drop(x.child.output.length) + Expand(groupingSetsAttributes, groupByAliases, expandedAttributes, gid, child) + } - val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr => - // collect all the found AggregateExpression, so we can check an expression is part of - // any AggregateExpression or not. - val aggsBuffer = ArrayBuffer[Expression]() - // Returns whether the expression belongs to any expressions in `aggsBuffer` or not. - def isPartOfAggregation(e: Expression): Boolean = { - aggsBuffer.exists(a => a.find(_ eq e).isDefined) + /* + * Construct new aggregate expressions by replacing grouping functions. 
+ */ + private def constructAggregateExprs( + groupByExprs: Seq[Expression], + aggregations: Seq[NamedExpression], + groupByAliases: Seq[Alias], + groupingAttrs: Seq[Expression], + gid: Attribute): Seq[NamedExpression] = aggregations.map { + // collect all the found AggregateExpression, so we can check an expression is part of + // any AggregateExpression or not. + val aggsBuffer = ArrayBuffer[Expression]() + // Returns whether the expression belongs to any expressions in `aggsBuffer` or not. + def isPartOfAggregation(e: Expression): Boolean = { + aggsBuffer.exists(a => a.find(_ eq e).isDefined) + } + replaceGroupingFunc(_, groupByExprs, gid).transformDown { + // AggregateExpression should be computed on the unmodified value of its argument + // expressions, so we should not replace any references to grouping expression + // inside it. + case e: AggregateExpression => + aggsBuffer += e + e + case e if isPartOfAggregation(e) => e + case e => + // Replace expression by expand output attribute. + val index = groupByAliases.indexWhere(_.child.semanticEquals(e)) + if (index == -1) { + e + } else { + groupingAttrs(index) } - replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown { - // AggregateExpression should be computed on the unmodified value of its argument - // expressions, so we should not replace any references to grouping expression - // inside it. - case e: AggregateExpression => - aggsBuffer += e - e - case e if isPartOfAggregation(e) => e - case e => - val index = groupByAliases.indexWhere(_.child.semanticEquals(e)) - if (index == -1) { - e - } else { - groupingAttrs(index) - } - }.asInstanceOf[NamedExpression] - } + }.asInstanceOf[NamedExpression] + } - Aggregate(groupingAttrs, aggregations, expand) + /* + * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets. 
+ */ + private def constructAggregate( + selectedGroupByExprs: Seq[Seq[Expression]], + groupByExprs: Seq[Expression], + aggregationExprs: Seq[NamedExpression], + child: LogicalPlan): LogicalPlan = { + val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() - case f @ Filter(cond, child) if hasGroupingFunction(cond) => - val groupingExprs = findGroupingExprs(child) - // The unresolved grouping id will be resolved by ResolveMissingReferences - val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute) - f.copy(condition = newCond) + // Expand works by setting grouping expressions to null as determined by the + // `selectedGroupByExprs`. To prevent these null values from being used in an aggregate + // instead of the original value we need to create new aliases for all group by expressions + // that will only be used for the intended purpose. + val groupByAliases = constructGroupByAlias(groupByExprs) - case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) => - val groupingExprs = findGroupingExprs(child) - val gid = VirtualColumn.groupingIdAttribute - // The unresolved grouping id will be resolved by ResolveMissingReferences - val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder]) - s.copy(order = newOrder) + val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid) + val groupingAttrs = expand.output.drop(child.output.length) + + val aggregations = constructAggregateExprs( + groupByExprs, aggregationExprs, groupByAliases, groupingAttrs, gid) + + Aggregate(groupingAttrs, aggregations, expand) } private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = { @@ -369,6 +393,41 @@ class Analyzer( failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") } } + + // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort + def apply(plan: LogicalPlan): LogicalPlan = plan 
transformUp { + case a if !a.childrenResolved => a // be sure all of the children are resolved. + case p if p.expressions.exists(hasGroupingAttribute) => + failAnalysis( + s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead") + + // Ensure group by expressions and aggregate expressions have been resolved. + case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) + if (groupByExprs ++ aggregateExpressions).forall(_.resolved) => + constructAggregate(cubeExprs(groupByExprs), groupByExprs, aggregateExpressions, child) + case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) + if (groupByExprs ++ aggregateExpressions).forall(_.resolved) => + constructAggregate(rollupExprs(groupByExprs), groupByExprs, aggregateExpressions, child) + // Ensure all the expressions have been resolved. + case x: GroupingSets if x.expressions.forall(_.resolved) => + constructAggregate(x.selectedGroupByExprs, x.groupByExprs, x.aggregations, x.child) + + // We should make sure all expressions in condition have been resolved. + case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved => + val groupingExprs = findGroupingExprs(child) + // The unresolved grouping id will be resolved by ResolveMissingReferences + val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute) + f.copy(condition = newCond) + + // We should make sure all [[SortOrder]]s have been resolved. 
+ case s @ Sort(order, _, child) + if order.exists(hasGroupingFunction) && order.forall(_.resolved) => + val groupingExprs = findGroupingExprs(child) + val gid = VirtualColumn.groupingIdAttribute + // The unresolved grouping id will be resolved by ResolveMissingReferences + val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder]) + s.copy(order = newOrder) + } } object ResolvePivot extends Rule[LogicalPlan] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4b151c81d8f8b1bfd0369a999ca790fee15778a7..2c4db0d2c34257c3775a2c9e535cec064a842149 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -492,33 +492,18 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { ctx: AggregationContext, selectExpressions: Seq[NamedExpression], query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - import ctx._ - val groupByExpressions = expressionList(groupingExpressions) + val groupByExpressions = expressionList(ctx.groupingExpressions) - if (GROUPING != null) { + if (ctx.GROUPING != null) { // GROUP BY .... GROUPING SETS (...) - val expressionMap = groupByExpressions.zipWithIndex.toMap - val numExpressions = expressionMap.size - val mask = (1 << numExpressions) - 1 - val masks = ctx.groupingSet.asScala.map { - _.expression.asScala.foldLeft(mask) { - case (bitmap, eCtx) => - // Find the index of the expression. - val e = typedVisit[Expression](eCtx) - val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse( - throw new ParseException( - s"$e doesn't show up in the GROUP BY list", ctx)) - // 0 means that the column at the given index is a grouping column, 1 means it is not, - // so we unset the bit in bitmap. 
- bitmap & ~(1 << (numExpressions - 1 - index)) - } - } - GroupingSets(masks, groupByExpressions, query, selectExpressions) + val selectedGroupByExprs = + ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e))) + GroupingSets(selectedGroupByExprs, groupByExpressions, query, selectExpressions) } else { // GROUP BY .... (WITH CUBE | WITH ROLLUP)? - val mappedGroupByExpressions = if (CUBE != null) { + val mappedGroupByExpressions = if (ctx.CUBE != null) { Seq(Cube(groupByExpressions)) - } else if (ROLLUP != null) { + } else if (ctx.ROLLUP != null) { Seq(Rollup(groupByExpressions)) } else { groupByExpressions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 65ceab2ce27b154177a574904e5968f44c0e1ad0..dcae7b026f58c75ef690b0ee69efccebabc6bc83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import scala.collection.mutable.ArrayBuffer - +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTypes @@ -523,51 +522,56 @@ case class Window( object Expand { /** - * Extract attribute set according to the grouping id. + * Build bit mask from attributes of selected grouping set. A bit in the bitmask is corresponding + * to an attribute in group by attributes sequence, the selected attribute has corresponding bit + * set to 0 and otherwise set to 1. For example, if we have GroupBy attributes (a, b, c, d), the + * bitmask 5(whose binary form is 0101) represents grouping set (a, c). 
* - * @param bitmask bitmask to represent the selected of the attribute sequence - * @param attrs the attributes in sequence - * @return the attributes of non selected specified via bitmask (with the bit set to 1) + * @param groupingSetAttrs The attributes of selected grouping set + * @param attrMap Mapping group by attributes to its index in attributes sequence + * @return The bitmask which represents the selected attributes out of group by attributes. */ - private def buildNonSelectAttrSet( - bitmask: Int, - attrs: Seq[Attribute]): AttributeSet = { - val nonSelect = new ArrayBuffer[Attribute]() - - var bit = attrs.length - 1 - while (bit >= 0) { - if (((bitmask >> bit) & 1) == 1) nonSelect += attrs(attrs.length - bit - 1) - bit -= 1 - } - - AttributeSet(nonSelect) + private def buildBitmask( + groupingSetAttrs: Seq[Attribute], + attrMap: Map[Attribute, Int]): Int = { + val numAttributes = attrMap.size + val mask = (1 << numAttributes) - 1 + // Calculate the attribute masks of selected grouping set. For example, if we have GroupBy + // attributes (a, b, c, d), grouping set (a, c) will produce the following sequence: + // (15, 7, 13), whose binary form is (1111, 0111, 1101) + val masks = (mask +: groupingSetAttrs.map(attrMap).map(index => + // 0 means that the column at the given index is a grouping column, 1 means it is not, + // so we unset the bit in bitmap. + ~(1 << (numAttributes - 1 - index)) + )) + // Reduce masks to generate a bitmask for the selected grouping set. + masks.reduce(_ & _) } /** * Apply the all of the GroupExpressions to every input row, hence we will get * multiple output rows for an input row. 
* - * @param bitmasks The bitmask set represents the grouping sets + * @param groupingSetsAttrs The attributes of grouping sets * @param groupByAliases The aliased original group by expressions * @param groupByAttrs The attributes of aliased group by expressions * @param gid Attribute of the grouping id * @param child Child operator */ def apply( - bitmasks: Seq[Int], + groupingSetsAttrs: Seq[Seq[Attribute]], groupByAliases: Seq[Alias], groupByAttrs: Seq[Attribute], gid: Attribute, child: LogicalPlan): Expand = { + val attrMap = groupByAttrs.zipWithIndex.toMap + // Create an array of Projections for the child projection, and replace the projections' // expressions which equal GroupBy expressions with Literal(null), if those expressions - // are not set for this grouping set (according to the bit mask). - val projections = bitmasks.map { bitmask => - // get the non selected grouping attributes according to the bit mask - val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, groupByAttrs) - + // are not set for this grouping set. + val projections = groupingSetsAttrs.map { groupingSetAttrs => child.output ++ groupByAttrs.map { attr => - if (nonSelectedGroupAttrSet.contains(attr)) { + if (!groupingSetAttrs.contains(attr)) { // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null Literal.create(null, attr.dataType) @@ -575,7 +579,7 @@ object Expand { attr } // groupingId is the last output, here we use the bit mask as the concrete value for it. 
- } :+ Literal.create(bitmask, IntegerType) + } :+ Literal.create(buildBitmask(groupingSetAttrs, attrMap), IntegerType) } // the `groupByAttrs` has different meaning in `Expand.output`, it could be the original @@ -616,16 +620,15 @@ case class Expand( * * We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer * - * @param bitmasks A list of bitmasks, each of the bitmask indicates the selected - * GroupBy expressions - * @param groupByExprs The Group By expressions candidates, take effective only if the - * associated bit in the bitmask set to 1. + * @param selectedGroupByExprs A sequence of selected GroupBy expressions, all exprs should + * exist in groupByExprs. + * @param groupByExprs The Group By expressions candidates. * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions * will be considered as constant null if it appears in the expressions */ case class GroupingSets( - bitmasks: Seq[Int], + selectedGroupByExprs: Seq[Seq[Expression]], groupByExprs: Seq[Expression], child: LogicalPlan, aggregations: Seq[NamedExpression]) extends UnaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..2a0205bdc90fe043c21a11c33452a384d807d390 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ + +class ResolveGroupingAnalyticsSuite extends AnalysisTest { + + lazy val a = 'a.int + lazy val b = 'b.string + lazy val c = 'c.string + lazy val unresolved_a = UnresolvedAttribute("a") + lazy val unresolved_b = UnresolvedAttribute("b") + lazy val unresolved_c = UnresolvedAttribute("c") + lazy val gid = 'spark_grouping_id.int.withNullability(false) + lazy val hive_gid = 'grouping__id.int.withNullability(false) + lazy val grouping_a = Cast(ShiftRight(gid, 1) & 1, ByteType) + lazy val nulInt = Literal(null, IntegerType) + lazy val nulStr = Literal(null, StringType) + lazy val r1 = LocalRelation(a, b, c) + + test("rollupExprs") { + val testRollup = (exprs: Seq[Expression], rollup: Seq[Seq[Expression]]) => { + val result = SimpleAnalyzer.ResolveGroupingAnalytics.rollupExprs(exprs) + assert(result.sortBy(_.hashCode) == rollup.sortBy(_.hashCode)) + } + + testRollup(Seq(a, b, c), Seq(Seq(), Seq(a), Seq(a, b), Seq(a, b, c))) + testRollup(Seq(c, b, a), Seq(Seq(), Seq(c), Seq(c, b), Seq(c, b, a))) + testRollup(Seq(a), Seq(Seq(), Seq(a))) + testRollup(Seq(), Seq(Seq())) + } + + test("cubeExprs") { + val testCube = (exprs: Seq[Expression], cube: 
Seq[Seq[Expression]]) => { + val result = SimpleAnalyzer.ResolveGroupingAnalytics.cubeExprs(exprs) + assert(result.sortBy(_.hashCode) == cube.sortBy(_.hashCode)) + } + + testCube(Seq(a, b, c), + Seq(Seq(), Seq(a), Seq(b), Seq(c), Seq(a, b), Seq(a, c), Seq(b, c), Seq(a, b, c))) + testCube(Seq(c, b, a), + Seq(Seq(), Seq(a), Seq(b), Seq(c), Seq(c, b), Seq(c, a), Seq(b, a), Seq(c, b, a))) + testCube(Seq(a), Seq(Seq(), Seq(a))) + testCube(Seq(), Seq(Seq())) + } + + test("grouping sets") { + val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Seq(unresolved_a, unresolved_b), r1, + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)))) + val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan, expected) + + val originalPlan2 = GroupingSets(Seq(), Seq(unresolved_a, unresolved_b), r1, + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)))) + val expected2 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), + Expand( + Seq(), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan2, expected2) + + val originalPlan3 = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b), + Seq(unresolved_c)), Seq(unresolved_a, unresolved_b), r1, + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)))) + assertAnalysisError(originalPlan3, Seq("doesn't show up in the GROUP BY list")) + } + + test("cube") { + val originalPlan = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))), + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) + val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), + Expand( + Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, 
nulStr, 1), + Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan, expected) + + val originalPlan2 = Aggregate(Seq(Cube(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1) + val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")), + Expand( + Seq(Seq(a, b, c, 0)), + Seq(a, b, c, gid), + Project(Seq(a, b, c), r1))) + checkAnalysis(originalPlan2, expected2) + } + + test("rollup") { + val originalPlan = Aggregate(Seq(Rollup(Seq(unresolved_a, unresolved_b))), + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) + val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), + Expand( + Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan, expected) + + val originalPlan2 = Aggregate(Seq(Rollup(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1) + val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")), + Expand( + Seq(Seq(a, b, c, 0)), + Seq(a, b, c, gid), + Project(Seq(a, b, c), r1))) + checkAnalysis(originalPlan2, expected2) + } + + test("grouping function") { + // GrouingSets + val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Seq(unresolved_a, unresolved_b), r1, + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), + UnresolvedAlias(Grouping(unresolved_a)))) + val expected = Aggregate(Seq(a, b, gid), + Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan, expected) + + // Cube + val originalPlan2 = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))), + 
Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), + UnresolvedAlias(Grouping(unresolved_a))), r1) + val expected2 = Aggregate(Seq(a, b, gid), + Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")), + Expand( + Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), + Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan2, expected2) + + // Rollup + val originalPlan3 = Aggregate(Seq(Rollup(Seq(unresolved_a, unresolved_b))), + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), + UnresolvedAlias(Grouping(unresolved_a))), r1) + val expected3 = Aggregate(Seq(a, b, gid), + Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")), + Expand( + Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan3, expected3) + } + + test("grouping_id") { + // GrouingSets + val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Seq(unresolved_a, unresolved_b), r1, + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), + UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b))))) + val expected = Aggregate(Seq(a, b, gid), + Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan, expected) + + // Cube + val originalPlan2 = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))), + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), + UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1) + val expected2 = Aggregate(Seq(a, b, gid), + Seq(a, b, count(c).as("count(c)"), 
gid.as("grouping_id(a, b)")), + Expand( + Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), + Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan2, expected2) + + // Rollup + val originalPlan3 = Aggregate(Seq(Rollup(Seq(unresolved_a, unresolved_b))), + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), + UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1) + val expected3 = Aggregate(Seq(a, b, gid), + Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")), + Expand( + Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan3, expected3) + } + + test("filter with grouping function") { + // Filter with Grouping function + val originalPlan = Filter(Grouping(unresolved_a) === 0, + GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) + val expected = Project(Seq(a, b), Filter(Cast(grouping_a, IntegerType) === 0, + Aggregate(Seq(a, b, gid), + Seq(a, b, gid), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) + checkAnalysis(originalPlan, expected) + + val originalPlan2 = Filter(Grouping(unresolved_a) === 0, + Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1)) + assertAnalysisError(originalPlan2, + Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) + + // Filter with GroupingID + val originalPlan3 = Filter(GroupingID(Seq(unresolved_a, unresolved_b)) === 1, + GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, 
unresolved_b))) + val expected3 = Project(Seq(a, b), Filter(gid === 1, + Aggregate(Seq(a, b, gid), + Seq(a, b, gid), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) + checkAnalysis(originalPlan3, expected3) + + val originalPlan4 = Filter(GroupingID(Seq(unresolved_a)) === 1, + Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1)) + assertAnalysisError(originalPlan4, + Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) + } + + test("sort with grouping function") { + // Sort with Grouping function + val originalPlan = Sort( + Seq(SortOrder(Grouping(unresolved_a), Ascending)), true, + GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) + val expected = Project(Seq(a, b), Sort( + Seq(SortOrder('aggOrder.byte.withNullability(false), Ascending)), true, + Aggregate(Seq(a, b, gid), + Seq(a, b, grouping_a.as("aggOrder")), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) + checkAnalysis(originalPlan, expected) + + val originalPlan2 = Sort(Seq(SortOrder(Grouping(unresolved_a), Ascending)), true, + Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1)) + assertAnalysisError(originalPlan2, + Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) + + // Sort with GroupingID + val originalPlan3 = Sort( + Seq(SortOrder(GroupingID(Seq(unresolved_a, unresolved_b)), Ascending)), true, + GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) + val expected3 = Project(Seq(a, b), Sort( + 
Seq(SortOrder('aggOrder.int.withNullability(false), Ascending)), true, + Aggregate(Seq(a, b, gid), + Seq(a, b, gid.as("aggOrder")), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) + checkAnalysis(originalPlan3, expected3) + + val originalPlan4 = Sort( + Seq(SortOrder(GroupingID(Seq(unresolved_a)), Ascending)), true, + Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1)) + assertAnalysisError(originalPlan4, + Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 7400f3430e99c7c6d6986f1016b3e95d65467380..5f0f6ee479c69750a6e87e1830c46a19f524364e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -233,9 +233,8 @@ class PlanParserSuite extends PlanTest { // Grouping Sets assertEqual(s"$sql grouping sets((a, b), (a), ())", - GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) - intercept(s"$sql grouping sets((a, b), (c), ())", - "c doesn't show up in the GROUP BY list") + GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b), table("d"), + Seq('a, 'b, 'sum.function('c).as("c")))) } test("limit") {