diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 54df96cd2446aa6328b221ea1415fb526eacdcf7..ec0c8b483a909e77c41e778e897c0be32ba0d406 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -49,4 +49,6 @@ case class Count(child: Expression) extends DeclarativeAggregate {
   )
 
   override val evaluateExpression = Cast(count, LongType)
+
+  override def defaultResult: Option[Literal] = Option(Literal(0L))
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index 644c6211d5f311f53320b5160b8c93fb8486434f..39010c3be6d4eec412bfb8146be6dd3341e1d082 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType}
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
  */
@@ -41,7 +42,7 @@ object Utils {
   private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
     case p: Aggregate if supportsGroupingKeySchema(p) =>
 
-      val converted = p.transformExpressionsDown {
+      val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
         case expressions.Average(child) =>
           aggregate.AggregateExpression2(
             aggregateFunction = aggregate.Average(child),
             mode = aggregate.Complete,
             isDistinct = false)
@@ -144,7 +145,8 @@ object Utils {
             aggregateFunction = aggregate.VarianceSamp(child),
             mode = aggregate.Complete,
             isDistinct = false)
-      }
+      })
+
       // Check if there is any expressions.AggregateExpression1 left.
       // If so, we cannot convert this plan.
       val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
@@ -156,6 +158,7 @@
       }
 
       // Check if there are multiple distinct columns.
+      // TODO remove this.
       val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
         expr.collect {
           case agg: AggregateExpression2 => agg
@@ -213,3 +216,178 @@
     case other => None
   }
 }
+
+/**
+ * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double
+ * aggregation in which the regular aggregation expressions and every distinct clause are
+ * aggregated in separate groups. The results are then combined in a second aggregate.
+ *
+ * TODO Expression canonicalization
+ * TODO Eliminate foldable expressions from distinct clauses.
+ * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
+ * operator. Perhaps this is a good thing? It is much simpler to plan later on...
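+ *
+ * As a rough illustration of the intended rewrite (column names here are hypothetical), a query
+ * such as SELECT COUNT(DISTINCT cat1), COUNT(DISTINCT cat2), SUM(value) FROM data GROUP BY key
+ * is expanded into one group of rows per distinct clause (plus one group for the regular
+ * aggregates), each tagged with a different 'gid' value. A first aggregate then de-duplicates
+ * each group on (key, cat1, cat2, gid), and a second aggregate combines the per-group results
+ * into a single row per key.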
+ */
+object MultipleDistinctRewriter extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    case a: Aggregate => rewrite(a)
+    case p => p
+  }
+
+  def rewrite(a: Aggregate): Aggregate = {
+
+    // Collect all aggregate expressions.
+    val aggExpressions = a.aggregateExpressions.flatMap { e =>
+      e.collect {
+        case ae: AggregateExpression2 => ae
+      }
+    }
+
+    // Extract distinct aggregate expressions.
+    val distinctAggGroups = aggExpressions
+      .filter(_.isDistinct)
+      .groupBy(_.aggregateFunction.children.toSet)
+
+    // Only continue to rewrite if there is more than one distinct group.
+    if (distinctAggGroups.size > 1) {
+      // Create the attributes for the grouping id and the group by clause.
+      val gid = new AttributeReference("gid", IntegerType, false)()
+      val groupByMap = a.groupingExpressions.collect {
+        case ne: NamedExpression => ne -> ne.toAttribute
+        case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)()
+      }
+      val groupByAttrs = groupByMap.map(_._2)
+
+      // Functions used to modify aggregate functions and their inputs.
+      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
+      def patchAggregateFunctionChildren(
+          af: AggregateFunction2,
+          id: Literal,
+          attrs: Map[Expression, Expression]): AggregateFunction2 = {
+        af.withNewChildren(af.children.map { case afc =>
+          evalWithinGroup(id, attrs(afc))
+        }).asInstanceOf[AggregateFunction2]
+      }
+
+      // Set up unique distinct aggregate children.
+      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
+      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
+      val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
+
+      // Set up expand & aggregate operators for distinct aggregate expressions.
+      val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
+        case ((group, expressions), i) =>
+          val id = Literal(i + 1)
+
+          // Expand projection
+          val projection = distinctAggChildren.map {
+            case e if group.contains(e) => e
+            case e => nullify(e)
+          } :+ id
+
+          // Final aggregate
+          val operators = expressions.map { e =>
+            val af = e.aggregateFunction
+            val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
+            (e, e.copy(aggregateFunction = naf, isDistinct = false))
+          }
+
+          (projection, operators)
+      }
+
+      // Set up expand for the 'regular' aggregate expressions.
+      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
+      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+      val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap
+
+      // Set up aggregates for 'regular' aggregate expressions.
+      val regularGroupId = Literal(0)
+      val regularAggOperatorMap = regularAggExprs.map { e =>
+        // Perform the actual aggregation in the initial aggregate.
+        val af = patchAggregateFunctionChildren(
+          e.aggregateFunction,
+          regularGroupId,
+          regularAggChildAttrMap)
+        val a = Alias(e.copy(aggregateFunction = af), e.toString)()
+
+        // Get the result of the first aggregate in the last aggregate.
+        val b = AggregateExpression2(
+          aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)),
+          mode = Complete,
+          isDistinct = false)
+
+        // Some aggregate functions (e.g. COUNT) have the special property that they can return a
+        // non-null result without any input. We need to make sure we return a result in this case.
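+        // COUNT, for instance, defines a default of 0 (see Count.defaultResult above), so its
+        // result for an empty group is obtained by coalescing the first aggregate's result with
+        // that default.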
+        val c = af.defaultResult match {
+          case Some(lit) => Coalesce(Seq(b, lit))
+          case None => b
+        }
+
+        (e, a, c)
+      }
+
+      // Construct the regular aggregate input projection only if we need one.
+      val regularAggProjection = if (regularAggExprs.nonEmpty) {
+        Seq(a.groupingExpressions ++
+          distinctAggChildren.map(nullify) ++
+          Seq(regularGroupId) ++
+          regularAggChildren)
+      } else {
+        Seq.empty[Seq[Expression]]
+      }
+
+      // Construct the distinct aggregate input projections.
+      val regularAggNulls = regularAggChildren.map(nullify)
+      val distinctAggProjections = distinctAggOperatorMap.map {
+        case (projection, _) =>
+          a.groupingExpressions ++
+            projection ++
+            regularAggNulls
+      }
+
+      // Construct the expand operator.
+      val expand = Expand(
+        regularAggProjection ++ distinctAggProjections,
+        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq,
+        a.child)
+
+      // Construct the first aggregate operator. This de-duplicates all the children of the
+      // distinct operators, and applies the regular aggregate operators.
+      val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
+      val firstAggregate = Aggregate(
+        firstAggregateGroupBy,
+        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
+        expand)
+
+      // Construct the second aggregate.
+      val transformations: Map[Expression, Expression] =
+        (distinctAggOperatorMap.flatMap(_._2) ++
+          regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+
+      val patchedAggExpressions = a.aggregateExpressions.map { e =>
+        e.transformDown {
+          case e: Expression =>
+            // The same GROUP BY clauses can have different forms (different names for instance)
+            // in the groupBy and aggregate expressions of an aggregate. This makes a map lookup
+            // tricky. So we do a linear search for a semantically equal group by expression.
+            groupByMap
+              .find(ge => e.semanticEquals(ge._1))
+              .map(_._2)
+              .getOrElse(transformations.getOrElse(e, e))
+        }.asInstanceOf[NamedExpression]
+      }
+      Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
+    } else {
+      a
+    }
+  }
+
+  private def nullify(e: Expression) = Literal.create(null, e.dataType)
+
+  private def expressionAttributePair(e: Expression) =
+    // We are creating a new reference here instead of reusing the attribute in case of a
+    // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
+    // children; with attribute reuse the input of the regular aggregate would be bound to the
+    // (nulled out) input of the distinct aggregate.
+    e -> new AttributeReference(e.prettyName, e.dataType, true)()
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index a2fab258fcac3cf16f4c0ea079be9a59b8469c4a..5c5b3d1ccd3cd0c677cc4d0930352602e15f2b24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -133,6 +133,12 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
    */
   def supportsPartial: Boolean = true
 
+  /**
+   * Result of the aggregate function when the input is empty. This is currently only used for the
+   * proper rewriting of distinct aggregate functions.
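+   * For example, COUNT returns 0 rather than null on empty input, which is why Count overrides
+   * this to Option(Literal(0L)); functions without such a non-null empty-input result keep the
+   * default of None.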
+   */
+  def defaultResult: Option[Literal] = None
+
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
     throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 338c5193cb7a2d25ee22736e404efd2dd85ff37a..d222dfa33ad8ab5f2a904ed6d98c7bf8544a5cab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -200,9 +200,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
-      if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty =>
-      a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references)))
+    case a @ Aggregate(_, _, e @ Expand(_, _, child))
+      if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty =>
+      a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references)))
 
     // Eliminate attributes that are not needed to calculate the specified aggregates.
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 4cb67aacf33ee703da6629345bd9e5bcd9d83a1b..fb963e2f8f7e7f0ab8d43a72ce9125f8ea14af5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -235,33 +235,17 @@ case class Window(
     projectList ++ windowExpressions.map(_.toAttribute)
 }
 
-/**
- * Apply the all of the GroupExpressions to every input row, hence we will get
- * multiple output rows for a input row.
- * @param bitmasks The bitmask set represents the grouping sets
- * @param groupByExprs The grouping by expressions
- * @param child Child operator
- */
-case class Expand(
-    bitmasks: Seq[Int],
-    groupByExprs: Seq[Expression],
-    gid: Attribute,
-    child: LogicalPlan) extends UnaryNode {
-  override def statistics: Statistics = {
-    val sizeInBytes = child.statistics.sizeInBytes * projections.length
-    Statistics(sizeInBytes = sizeInBytes)
-  }
-
-  val projections: Seq[Seq[Expression]] = expand()
-
+private[sql] object Expand {
   /**
-   * Extract attribute set according to the grouping id
+   * Extract attribute set according to the grouping id.
+   *
    * @param bitmask bitmask to represent the selected of the attribute sequence
    * @param exprs the attributes in sequence
    * @return the attributes of non selected specified via bitmask (with the bit set to 1)
    */
-  private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
-    : OpenHashSet[Expression] = {
+  private def buildNonSelectExprSet(
+      bitmask: Int,
+      exprs: Seq[Expression]): OpenHashSet[Expression] = {
     val set = new OpenHashSet[Expression](2)
 
     var bit = exprs.length - 1
@@ -274,18 +258,28 @@ case class Expand(
   }
 
   /**
-   * Create an array of Projections for the child projection, and replace the projections'
-   * expressions which equal GroupBy expressions with Literal(null), if those expressions
-   * are not set for this grouping set (according to the bit mask).
+   * Apply all of the GroupExpressions to every input row, hence we will get
+   * multiple output rows per input row.
+   *
+   * @param bitmasks The set of bitmasks that represents the grouping sets
+   * @param groupByExprs The group by expressions
+   * @param gid Attribute of the grouping id
+   * @param child Child operator
    */
-  private[this] def expand(): Seq[Seq[Expression]] = {
-    val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
-
-    bitmasks.foreach { bitmask =>
+  def apply(
+      bitmasks: Seq[Int],
+      groupByExprs: Seq[Expression],
+      gid: Attribute,
+      child: LogicalPlan): Expand = {
+    // Create an array of Projections for the child projection, and replace the projections'
+    // expressions which equal GroupBy expressions with Literal(null), if those expressions
+    // are not set for this grouping set (according to the bit mask).
+    val projections = bitmasks.map { bitmask =>
       // get the non selected grouping attributes according to the bit mask
       val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
 
-      val substitution = (child.output :+ gid).map(expr => expr transformDown {
+      (child.output :+ gid).map(expr => expr transformDown {
+        // TODO this causes a problem when a column is used both for grouping and aggregation.
         case x: Expression if nonSelectedGroupExprSet.contains(x) =>
           // if the input attribute in the Invalid Grouping Expression set of for this group
           // replace it with constant null
@@ -294,15 +288,29 @@ case class Expand(
           // replace the groupingId with concrete value (the bit mask)
           Literal.create(bitmask, IntegerType)
       })
-
-      result += substitution
     }
-
-    result.toSeq
+    Expand(projections, child.output :+ gid, child)
   }
+}
 
-  override def output: Seq[Attribute] = {
-    child.output :+ gid
+/**
+ * Apply a number of projections to every input row, hence we will get multiple output rows for
+ * an input row.
+ *
+ * @param projections the projections to apply
+ * @param output the output attributes of all projections
+ * @param child the child operator
+ */
+case class Expand(
+    projections: Seq[Seq[Expression]],
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+
+  override def statistics: Statistics = {
+    // TODO shouldn't we factor in the size of the projection versus the size of the backing child
+    // row?
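+    // For now this conservatively assumes that every projected row is as wide as an input row,
+    // i.e. the estimate is simply the child's size multiplied by the number of projections.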
+    val sizeInBytes = child.statistics.sizeInBytes * projections.length
+    Statistics(sizeInBytes = sizeInBytes)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f4464e0b916f8f2d7512780aeb664d6860e8cdfd..dd3bb33c5728704236ef1ef478e07ccb903ec791 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -420,7 +420,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
-      case e @ logical.Expand(_, _, _, child) =>
+      case e @ logical.Expand(_, _, child) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
       case a @ logical.Aggregate(group, agg, child) => {
         val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled