Skip to content
Snippets Groups Projects
Commit ef362846 authored by Herman van Hovell's avatar Herman van Hovell Committed by Yin Huai
Browse files

[SPARK-9241][SQL] Supporting multiple DISTINCT columns - follow-up

This PR is a follow up for PR https://github.com/apache/spark/pull/9406. It adds more documentation to the rewriting rule, removes a redundant if expression in the non-distinct aggregation path and adds a multiple distinct test to the AggregationQuerySuite.

cc yhuai marmbrus

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #9541 from hvanhovell/SPARK-9241-followup.
parent 2ff0e79a
No related branches found
No related tags found
No related merge requests found
......@@ -222,10 +222,76 @@ object Utils {
* aggregation in which the regular aggregation expressions and every distinct clause is aggregated
* in a separate group. The results are then combined in a second aggregate.
*
* TODO Expression cannocalization
* TODO Eliminate foldable expressions from distinct clauses.
* TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
* operator. Perhaps this is a good thing? It is much simpler to plan later on...
* For example (in scala):
* {{{
* val data = Seq(
* ("a", "ca1", "cb1", 10),
* ("a", "ca1", "cb2", 5),
* ("b", "ca1", "cb1", 13))
* .toDF("key", "cat1", "cat2", "value")
* data.registerTempTable("data")
*
* val agg = data.groupBy($"key")
* .agg(
* countDistinct($"cat1").as("cat1_cnt"),
* countDistinct($"cat2").as("cat2_cnt"),
* sum($"value").as("total"))
* }}}
*
* This translates to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [COUNT(DISTINCT 'cat1),
* COUNT(DISTINCT 'cat2),
* sum('value)]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* LocalTableScan [...]
* }}}
*
* This rule rewrites this logical plan to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [count(if (('gid = 1)) 'cat1 else null),
* count(if (('gid = 2)) 'cat2 else null),
* first(if (('gid = 0)) 'total else null) ignore nulls]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* Aggregate(
* key = ['key, 'cat1, 'cat2, 'gid]
* functions = [sum('value)]
* output = ['key, 'cat1, 'cat2, 'gid, 'total])
* Expand(
* projections = [('key, null, null, 0, cast('value as bigint)),
* ('key, 'cat1, null, 1, null),
* ('key, null, 'cat2, 2, null)]
* output = ['key, 'cat1, 'cat2, 'gid, 'value])
* LocalTableScan [...]
* }}}
*
* The rule does the following things here:
* 1. Expand the data. There are three aggregation groups in this query:
* i. the non-distinct group;
* ii. the distinct 'cat1 group;
* iii. the distinct 'cat2 group.
* An expand operator is inserted to expand the child data for each group. The expand will null
* out all unused columns for the given group; this must be done in order to ensure correctness
* later on. Groups can by identified by a group id (gid) column added by the expand operator.
* 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of
* this aggregate consists of the original group by clause, all the requested distinct columns
* and the group id. Both de-duplication of distinct column and the aggregation of the
* non-distinct group take advantage of the fact that we group by the group id (gid) and that we
* have nulled out all non-relevant columns for the the given group.
* 3. Aggregating the distinct groups and combining this with the results of the non-distinct
* aggregation. In this step we use the group id to filter the inputs for the aggregate
* functions. The result of the non-distinct group are 'aggregated' by using the first operator,
* it might be more elegant to use the native UDAF merge mechanism for this in the future.
*
* This rule duplicates the input data by two or more times (# distinct groups + an optional
* non-distinct group). This will put quite a bit of memory pressure of the used aggregate and
* exchange operators. Keeping the number of distinct groups as low a possible should be priority,
* we could improve this in the current rule by applying more advanced expression cannocalization
* techniques.
*/
object MultipleDistinctRewriter extends Rule[LogicalPlan] {
......@@ -261,11 +327,10 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Functions used to modify aggregate functions and their inputs.
def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
def patchAggregateFunctionChildren(
af: AggregateFunction2,
id: Literal,
attrs: Map[Expression, Expression]): AggregateFunction2 = {
af.withNewChildren(af.children.map { case afc =>
evalWithinGroup(id, attrs(afc))
af: AggregateFunction2)(
attrs: Expression => Expression): AggregateFunction2 = {
af.withNewChildren(af.children.map {
case afc => attrs(afc)
}).asInstanceOf[AggregateFunction2]
}
......@@ -288,7 +353,9 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Final aggregate
val operators = expressions.map { e =>
val af = e.aggregateFunction
val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
val naf = patchAggregateFunctionChildren(af) { x =>
evalWithinGroup(id, distinctAggChildAttrMap(x))
}
(e, e.copy(aggregateFunction = naf, isDistinct = false))
}
......@@ -304,26 +371,27 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
val regularGroupId = Literal(0)
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
val af = patchAggregateFunctionChildren(
e.aggregateFunction,
regularGroupId,
regularAggChildAttrMap)
val a = Alias(e.copy(aggregateFunction = af), e.toString)()
// Get the result of the first aggregate in the last aggregate.
val b = AggregateExpression2(
aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)),
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap)
val operator = Alias(e.copy(aggregateFunction = af), e.toString)()
// Select the result of the first aggregate in the last aggregate.
val result = AggregateExpression2(
aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
mode = Complete,
isDistinct = false)
// Some aggregate functions (COUNT) have the special property that they can return a
// non-null result without any input. We need to make sure we return a result in this case.
val c = af.defaultResult match {
case Some(lit) => Coalesce(Seq(b, lit))
case None => b
val resultWithDefault = af.defaultResult match {
case Some(lit) => Coalesce(Seq(result, lit))
case None => result
}
(e, a, c)
// Return a Tuple3 containing:
// i. The original aggregate expression (used for look ups).
// ii. The actual aggregation operator (used in the first aggregate).
// iii. The operator that selects and returns the result (used in the second aggregate).
(e, operator, resultWithDefault)
}
// Construct the regular aggregate input projection only if we need one.
......
......@@ -516,6 +516,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(3, 4, 4, 3, null) :: Nil)
}
test("multiple distinct column sets") {
checkAnswer(
sqlContext.sql(
"""
|SELECT
| key,
| count(distinct value1),
| count(distinct value2)
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(null, 3, 3) ::
Row(1, 2, 3) ::
Row(2, 2, 1) ::
Row(3, 0, 1) :: Nil)
}
test("test count") {
checkAnswer(
sqlContext.sql(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment