Commit f728e0fe authored by Cheng Hao, committed by Michael Armbrust

[SPARK-2663] [SQL] Support the Grouping Set

Add support for `GROUPING SETS`, `ROLLUP`, `CUBE` and the virtual column `GROUPING__ID`.

More details on how to use `GROUPING SETS` can be found at: https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation,+Cube,+Grouping+and+Rollup
https://issues.apache.org/jira/secure/attachment/12676811/grouping_set.pdf

The general idea of the implementation is:
1 Replace `ROLLUP` and `CUBE` with `GROUPING SETS`
2 Expand each input row, and then feed the expanded rows to `Aggregate`
  * Each grouping set is represented as a bit mask over the `GroupBy Expression List`; for each bit, `1` means the expression is selected and `0` means it is not (the leftmost expression in the `GroupBy Expression List` is the lowest bit, the rightmost is the highest bit)
  * One projection is constructed per grouping set, and within each projection (`Seq[Expression]`) we replace an expression with `Literal(null)` if it is not selected in the grouping set (based on the bit mask)
  * The output schema of `Expand` is `child.output :+ grouping__id`
  * The GroupBy expressions of `Aggregate` are `GroupBy Expression List :+ grouping__id`
  * The `Aggregation expressions` of the `Aggregate` are kept unchanged
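
The bit mask scheme in the steps above can be illustrated with a small, Spark-free sketch. The object name `GroupingSetsSketch` and the helpers `rollupMasks`, `cubeMasks` and `expandRow` are hypothetical names chosen for this illustration; rows are modelled as plain sequences of key values and `None` stands in for SQL `NULL`. It is only meant to show how the masks relate to the expanded rows, not how the analyzer rule is written.

```scala
object GroupingSetsSketch {
  // Bit i of a mask is 1 when the i-th GROUP BY expression is selected.
  // ROLLUP over n keys yields n + 1 masks: (), (k0), (k0, k1), ...
  def rollupMasks(n: Int): Seq[Int] = Seq.tabulate(n + 1)(i => (1 << i) - 1)

  // CUBE over n keys yields all 2^n subsets.
  def cubeMasks(n: Int): Seq[Int] = Seq.tabulate(1 << n)(identity)

  // Expand the grouping keys of one input row into one row per mask:
  // unselected keys become None (standing in for NULL), and the mask
  // itself is returned alongside as the grouping id.
  def expandRow(keys: Seq[Any], masks: Seq[Int]): Seq[(Seq[Option[Any]], Int)] =
    masks.map { mask =>
      val projected = keys.zipWithIndex.map { case (k, i) =>
        if (((mask >> i) & 1) == 1) Some(k) else None
      }
      (projected, mask)
    }
}
```

For example, `GroupingSetsSketch.expandRow(Seq("a1", "b1"), GroupingSetsSketch.rollupMasks(2))` produces three rows with grouping ids 0, 1 and 3, which an aggregate grouped on the keys plus the grouping id can then process in a single pass.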

The expression substitutions happen during logical plan analysis, so we benefit from the logical plan optimizations (e.g. expression constant folding and map-side aggregation). Only one new operator, `Expand`, is added to the physical plan; it expands the rows according to the pre-set projections.
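
As a rough sketch of what the physical operator described above does, the following applies every pre-set projection to every input row. The names `ExpandSketch`, `Row` and `expand` are hypothetical, and the sketch uses `flatMap` for brevity where the operator in this commit uses a hand-written iterator; a projection is modelled as a plain function from row to row.

```scala
object ExpandSketch {
  // A row is modelled as a plain sequence of values (an assumption of this sketch).
  type Row = Seq[Any]

  // Emit one output row per (input row, projection) pair, lazily.
  def expand(rows: Iterator[Row], projections: Seq[Row => Row]): Iterator[Row] =
    rows.flatMap(row => projections.iterator.map(p => p(row)))
}
```

With two projections, one that nulls out the only key and tags the row with grouping id 0 and one that keeps the key and tags it with 1, each input row yields two output rows in projection order.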

A known issue, to be addressed in a follow-up PR:
* The `ColumnPruning` optimization is not yet supported for the `Expand` node.

Author: Cheng Hao <hao.cheng@intel.com>

Closes #1567 from chenghao-intel/grouping_sets and squashes the following commits:

fe65fcc [Cheng Hao] Remove the extra space
3547056 [Cheng Hao] Add more doc and Simplify the Expand
a7c869d [Cheng Hao] update code as feedbacks
d23c672 [Cheng Hao] Add GroupingExpression to replace the Seq[Expression]
414b165 [Cheng Hao] revert the unnecessary changes
ec276c6 [Cheng Hao] Support Rollup/Cube/GroupingSets
parent 9804a759
415 additions and 11 deletions
......@@ -17,11 +17,13 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.catalyst.types.IntegerType
/**
* A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
......@@ -56,6 +58,7 @@ class Analyzer(catalog: Catalog,
Batch("Resolution", fixedPoint,
ResolveReferences ::
ResolveRelations ::
ResolveGroupingAnalytics ::
ResolveSortReferences ::
NewRelationInstances ::
ImplicitGenerate ::
......@@ -102,6 +105,93 @@ class Analyzer(catalog: Catalog,
}
}
object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
/**
* Extract the set of non-selected attributes according to the grouping id
* @param bitmask bitmask representing which attributes in the sequence are selected
* @param exprs the attributes in sequence
* @return the non-selected attributes, i.e. those whose bit in the bitmask is 0
*/
private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
: OpenHashSet[Expression] = {
val set = new OpenHashSet[Expression](2)
var bit = exprs.length - 1
while (bit >= 0) {
if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
bit -= 1
}
set
}
/*
* GROUP BY a, b, c WITH ROLLUP
* is equivalent to
* GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (a), ( )).
* Group Count: N + 1 (N is the number of group expressions)
*
* We need to get all of the subsets required by the rule described above; each subset
* is represented as a bit mask.
*/
def bitmasks(r: Rollup): Seq[Int] = {
Seq.tabulate(r.groupByExprs.length + 1)(idx => {(1 << idx) - 1})
}
/*
* GROUP BY a, b, c WITH CUBE
* is equivalent to
* GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (b, c), (a, c), (a), (b), (c), ( ) ).
* Group Count: 2^N (N is the number of group expressions)
*
* We need to get all of the subsets of the given GROUP BY expressions; each subset
* is represented as a bit mask.
*/
def bitmasks(c: Cube): Seq[Int] = {
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
}
/**
* Create an array of Projections for the child projection, and replace the projections'
* expressions which equal GroupBy expressions with Literal(null), if those expressions
* are not set for this grouping set (according to the bit mask).
*/
private[this] def expand(g: GroupingSets): Seq[GroupExpression] = {
val result = new scala.collection.mutable.ArrayBuffer[GroupExpression]
g.bitmasks.foreach { bitmask =>
// get the non selected grouping attributes according to the bit mask
val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs)
val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
case x: Expression if nonSelectedGroupExprSet.contains(x) =>
// if the input attribute is in the non-selected grouping expression set for this group,
// replace it with constant null
Literal(null, expr.dataType)
case x if x == g.gid =>
// replace the groupingId with concrete value (the bit mask)
Literal(bitmask, IntegerType)
})
result += GroupExpression(substitution)
}
result.toSeq
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a: Cube if a.resolved =>
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
case a: Rollup if a.resolved =>
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
case x: GroupingSets if x.resolved =>
Aggregate(
x.groupByExprs :+ x.gid,
x.aggregations,
Expand(expand(x), x.child.output :+ x.gid, x.child))
}
}
/**
* Checks for non-aggregated attributes with aggregation
*/
......@@ -183,6 +273,11 @@ class Analyzer(catalog: Catalog,
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressions {
case u @ UnresolvedAttribute(name)
if resolver(name, VirtualColumn.groupingIdName) &&
q.isInstanceOf[GroupingAnalytics] =>
// Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics
q.asInstanceOf[GroupingAnalytics].gid
case u @ UnresolvedAttribute(name) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result = q.resolveChildren(name, resolver).getOrElse(u)
......
......@@ -284,6 +284,17 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>
}
// TODO Semantically we probably do not need GroupExpression.
// All it does is hold the Seq[Expression], and it is ONLY used to apply the
// expression transformations correctly. It will probably be removed since it is
// not a real expression.
case class GroupExpression(children: Seq[Expression]) extends Expression {
self: Product =>
type EvaluatedType = Seq[Any]
override def eval(input: Row): EvaluatedType = ???
override def nullable = false
override def foldable = false
override def dataType = ???
}
......@@ -187,3 +187,8 @@ case class AttributeReference(
override def toString: String = s"$name#${exprId.id}$typeSuffix"
}
object VirtualColumn {
val groupingIdName = "grouping__id"
def newGroupingId = AttributeReference(groupingIdName, IntegerType, false)()
}
......@@ -143,6 +143,89 @@ case class Aggregate(
override def output = aggregateExpressions.map(_.toAttribute)
}
/**
* Apply all of the GroupExpressions to every input row, hence we will get
* multiple output rows for an input row.
* @param projections The group of expressions, all of the group expressions should
* output the same schema specified by the parameter `output`
* @param output The output Schema
* @param child Child operator
*/
case class Expand(
projections: Seq[GroupExpression],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode
trait GroupingAnalytics extends UnaryNode {
self: Product =>
def gid: AttributeReference
def groupByExprs: Seq[Expression]
def aggregations: Seq[NamedExpression]
override def output = aggregations.map(_.toAttribute)
}
/**
* A GROUP BY clause with GROUPING SETS can generate a result set equivalent
* to that generated by a UNION ALL of multiple simple GROUP BY clauses.
*
* We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer
* @param bitmasks A list of bitmasks, each of which indicates the selected
* GroupBy expressions
* @param groupByExprs The candidate Group By expressions; each takes effect only if the
* associated bit in the bitmask is set to 1.
* @param child Child operator
* @param aggregations The Aggregation expressions; any non-selected group by expression
* that appears in them is treated as constant null
* @param gid The attribute representing the virtual column GROUPING__ID; it is also
* the bitmask indicating the selected GroupBy expressions for each
* aggregated output row.
* The associated output will be one of the values in `bitmasks`
*/
case class GroupingSets(
bitmasks: Seq[Int],
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
/**
* Cube is syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
* and eventually will be transformed to Aggregate(.., Expand) in Analyzer
*
* @param groupByExprs The candidate Group By expressions.
* @param child Child operator
* @param aggregations The Aggregation expressions; any non-selected group by expression
* that appears in them is treated as constant null
* @param gid The attribute representing the virtual column GROUPING__ID; it is also
* the bitmask indicating the selected GroupBy expressions for each
* aggregated output row.
*/
case class Cube(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
/**
* Rollup is syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
* and eventually will be transformed to Aggregate(.., Expand) in Analyzer
*
* @param groupByExprs The candidate Group By expressions; each takes effect only if the
* associated bit in the bitmask is set to 1.
* @param child Child operator
* @param aggregations The Aggregation expressions; any non-selected group by expression
* that appears in them is treated as constant null
* @param gid The attribute representing the virtual column GROUPING__ID; it is also
* the bitmask indicating the selected GroupBy expressions for each
* aggregated output row.
*/
case class Rollup(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output = child.output
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, Partitioning}
/**
* Apply all of the GroupExpressions to every input row, hence we will get
* multiple output rows for an input row.
* @param projections The group of expressions, all of the group expressions should
* output the same schema specified by the parameter `output`
* @param output The output Schema
* @param child Child operator
*/
@DeveloperApi
case class Expand(
projections: Seq[GroupExpression],
output: Seq[Attribute],
child: SparkPlan)
extends UnaryNode {
// The GroupExpressions can output data with arbitrary partitioning, so set it
// as UNKNOWN partitioning
override def outputPartitioning: Partitioning = UnknownPartitioning(0)
override def execute() = attachTree(this, "execute") {
child.execute().mapPartitions { iter =>
// TODO Move out projection objects creation and transfer to
// workers via closure. However we can't assume the Projection
// is serializable because of the code gen, so we have to
// create the projections within each of the partition processing.
val groups = projections.map(ee => newProjection(ee.children, child.output)).toArray
new Iterator[Row] {
private[this] var result: Row = _
private[this] var idx = -1 // -1 means the initial state
private[this] var input: Row = _
override final def hasNext = (-1 < idx && idx < groups.length) || iter.hasNext
override final def next(): Row = {
if (idx <= 0) {
// in the initial (-1) or beginning(0) of a new input row, fetch the next input tuple
input = iter.next()
idx = 0
}
result = groups(idx)(input)
idx += 1
if (idx == groups.length && iter.hasNext) {
idx = 0
}
result
}
}
}
}
}
......@@ -270,6 +270,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Project(projectList, planLater(child)) :: Nil
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
case logical.Expand(projections, output, child) =>
execution.Expand(projections, output, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
......
......@@ -403,6 +403,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"groupby11",
"groupby12",
"groupby1_limit",
"groupby_grouping_id1",
"groupby_grouping_id2",
"groupby_grouping_sets1",
"groupby_grouping_sets2",
"groupby_grouping_sets3",
"groupby_grouping_sets4",
"groupby_grouping_sets5",
"groupby1_map",
"groupby1_map_nomap",
"groupby1_map_skew",
......
......@@ -393,6 +393,42 @@ private[hive] object HiveQl {
(db, tableName)
}
/**
* SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2))
* is equivalent to
* SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2
* Check the following link for details.
*
* https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup
*
* The bitmask denotes which grouping expressions are valid for a grouping set;
* the bitmask is also called the grouping id (`GROUPING__ID`, the virtual column in Hive)
* e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of
* GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively.
*/
protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = {
val (keyASTs, setASTs) = children.partition( n => n match {
case Token("TOK_GROUPING_SETS_EXPRESSION", children) => false // grouping sets
case _ => true // grouping keys
})
val keys = keyASTs.map(nodeToExpr).toSeq
val keyMap = keyASTs.map(_.toStringTree).zipWithIndex.toMap
val bitmasks: Seq[Int] = setASTs.map(set => set match {
case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0
case Token("TOK_GROUPING_SETS_EXPRESSION", children) =>
children.foldLeft(0)((bitmap, col) => {
val colString = col.asInstanceOf[ASTNode].toStringTree()
require(keyMap.contains(colString), s"$colString doesn't show up in the GROUP BY list")
bitmap | 1 << keyMap(colString)
})
case _ => sys.error("Expect GROUPING SETS clause")
})
(keys, bitmasks)
}
protected def nodeToPlan(node: Node): LogicalPlan = node match {
// Special drop table that also uncaches.
case Token("TOK_DROPTABLE",
......@@ -520,6 +556,9 @@ private[hive] object HiveQl {
selectDistinctClause ::
whereClause ::
groupByClause ::
rollupGroupByClause ::
cubeGroupByClause ::
groupingSetsClause ::
orderByClause ::
havingClause ::
sortByClause ::
......@@ -535,6 +574,9 @@ private[hive] object HiveQl {
"TOK_SELECTDI",
"TOK_WHERE",
"TOK_GROUPBY",
"TOK_ROLLUP_GROUPBY",
"TOK_CUBE_GROUPBY",
"TOK_GROUPING_SETS",
"TOK_ORDERBY",
"TOK_HAVING",
"TOK_SORTBY",
......@@ -603,16 +645,33 @@ private[hive] object HiveQl {
// The projection of the query can either be a normal projection, an aggregation
// (if there is a group by) or a script transformation.
val withProject = transformation.getOrElse {
// Not a transformation so must be either project or aggregation.
val selectExpressions = nameExpressions(select.getChildren.flatMap(selExprNodeToExpr))
groupByClause match {
case Some(groupBy) =>
Aggregate(groupBy.getChildren.map(nodeToExpr), selectExpressions, withLateralView)
case None =>
Project(selectExpressions, withLateralView)
}
val withProject: LogicalPlan = transformation.getOrElse {
val selectExpressions =
nameExpressions(select.getChildren.flatMap(selExprNodeToExpr).toSeq)
Seq(
groupByClause.map(e => e match {
case Token("TOK_GROUPBY", children) =>
// Not a transformation so must be either project or aggregation.
Aggregate(children.map(nodeToExpr), selectExpressions, withLateralView)
case _ => sys.error("Expect GROUP BY")
}),
groupingSetsClause.map(e => e match {
case Token("TOK_GROUPING_SETS", children) =>
val(groupByExprs, masks) = extractGroupingSet(children)
GroupingSets(masks, groupByExprs, withLateralView, selectExpressions)
case _ => sys.error("Expect GROUPING SETS")
}),
rollupGroupByClause.map(e => e match {
case Token("TOK_ROLLUP_GROUPBY", children) =>
Rollup(children.map(nodeToExpr), withLateralView, selectExpressions)
case _ => sys.error("Expect WITH ROLLUP")
}),
cubeGroupByClause.map(e => e match {
case Token("TOK_CUBE_GROUPBY", children) =>
Cube(children.map(nodeToExpr), withLateralView, selectExpressions)
case _ => sys.error("Expect WITH CUBE")
}),
Some(Project(selectExpressions, withLateralView))).flatten.head
}
val withDistinct =
......
NULL NULL 0
NULL 11 2
NULL 12 2
NULL 13 2
NULL 17 2
NULL 18 2
NULL 28 2
1 NULL 1
1 11 3
2 NULL 1
2 12 3
3 NULL 1
3 13 3
7 NULL 1
7 17 3
8 NULL 1
8 18 3
8 28 3
0 NULL NULL
1 1 NULL
3 1 11
1 2 NULL
3 2 12
1 3 NULL
3 3 13
1 7 NULL
3 7 17
1 8 NULL
3 8 18
3 8 28
NULL NULL 0 0
NULL 11 2 2
NULL 12 2 2
NULL 13 2 2
NULL 17 2 2
NULL 18 2 2
NULL 28 2 2
1 NULL 1 1
1 11 3 3
2 NULL 1 1
2 12 3 3
3 NULL 1 1
3 13 3 3
7 NULL 1 1
7 17 3 3
8 NULL 1 1
8 18 3 3
8 28 3 3
NULL NULL 0 6
1 NULL 1 2
1 NULL 3 1
1 1 3 1
2 NULL 1 1
2 2 3 1
3 NULL 1 2
3 NULL 3 1
3 3 3 1
4 NULL 1 1
4 5 3 1
0 1
1 4
3 6