From 1ebd41b141a95ec264bd2dd50f0fe24cd459035d Mon Sep 17 00:00:00 2001 From: Yin Huai <yhuai@databricks.com> Date: Mon, 3 Aug 2015 00:23:08 -0700 Subject: [PATCH] [SPARK-9240] [SQL] Hybrid aggregate operator using unsafe row This PR adds a base aggregation iterator, `AggregationIterator`, which is used to create `SortBasedAggregationIterator` (for sort-based aggregation) and `UnsafeHybridAggregationIterator` (which first tries hash-based aggregation and falls back to sort-based aggregation, using an external sorter, if we cannot allocate memory for the map). With these two iterators, the existing iterators are no longer needed, so I am removing them. Also, we can use a single physical `Aggregate` operator that internally determines which iterator to use. https://issues.apache.org/jira/browse/SPARK-9240 Author: Yin Huai <yhuai@databricks.com> Closes #7813 from yhuai/AggregateOperator and squashes the following commits: e317e2b [Yin Huai] Remove unnecessary change. 74d93c5 [Yin Huai] Merge remote-tracking branch 'upstream/master' into AggregateOperator ba6afbc [Yin Huai] Add a little bit more comments. c9cf3b6 [Yin Huai] update 0f1b06f [Yin Huai] Remove unnecessary code. 21fd15f [Yin Huai] Remove unnecessary change. 964f88b [Yin Huai] Implement fallback strategy. b1ea5cf [Yin Huai] wip 7fcbd87 [Yin Huai] Add a flag to control which iterator to use. 533d5b2 [Yin Huai] Prepare for fallback! 33b7022 [Yin Huai] wip bd9282b [Yin Huai] UDAFs now support UnsafeRow. f52ee53 [Yin Huai] wip 3171f44 [Yin Huai] wip d2c45a0 [Yin Huai] wip f60cc83 [Yin Huai] Also check input schema. af32210 [Yin Huai] Check iter.hasNext before we create an iterator because the constructor of the iterator will read at least one row from a non-empty input iter. 299008c [Yin Huai] First round cleanup. 3915bac [Yin Huai] Create a base iterator class for aggregation iterators and add the initial version of the hybrid iterator.
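(Reviewer note: the fallback strategy described above boils down to the control flow sketched below. This is an illustrative sketch only, not code from this patch: it uses plain Scala collections with String keys and Long sums, a fixed map capacity stands in for an allocation failure in UnsafeFixedWidthAggregationMap, and the object and parameter names are made up.)

    object HybridAggregationSketch {
      // Sums values per key: hash-based until the map is "full", then sort-based.
      def aggregate(
          input: Iterator[(String, Long)],
          maxMapEntries: Int): Iterator[(String, Long)] = {
        val hashMap = scala.collection.mutable.HashMap.empty[String, Long]
        var pendingRow: Option[(String, Long)] = None
        // Hash-based phase: update per-key buffers in place.
        while (pendingRow.isEmpty && input.hasNext) {
          val (key, value) = input.next()
          if (hashMap.contains(key) || hashMap.size < maxMapEntries) {
            hashMap(key) = hashMap.getOrElse(key, 0L) + value
          } else {
            // Simulated allocation failure: remember the row that did not fit.
            pendingRow = Some((key, value))
          }
        }
        pendingRow match {
          case None => hashMap.iterator
          case Some(row) =>
            // Sort-based phase: "spill" the partial aggregates together with the
            // remaining input, sort by key, and merge runs of equal keys.
            val spilled = (hashMap.toSeq :+ row) ++ input
            spilled.sortBy(_._1)
              .foldLeft(List.empty[(String, Long)]) {
                case ((k, sum) :: rest, (key, v)) if k == key => (k, sum + v) :: rest
                case (acc, (key, v)) => (key, v) :: acc
              }
              .reverseIterator
        }
      }
    }

For example, HybridAggregationSketch.aggregate(Iterator("a" -> 1L, "b" -> 2L, "a" -> 3L), maxMapEntries = 1) exercises the fallback path and still yields ("a", 4) and ("b", 2).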
--- .../expressions/aggregate/interfaces.scala | 19 +- .../sql/execution/aggregate/Aggregate.scala | 182 +++++ .../aggregate/AggregationIterator.scala | 490 +++++++++++++ .../SortBasedAggregationIterator.scala | 236 +++++++ .../UnsafeHybridAggregationIterator.scala | 398 +++++++++++ .../aggregate/aggregateOperators.scala | 175 ----- .../aggregate/sortBasedIterators.scala | 664 ------------------ .../spark/sql/execution/aggregate/udaf.scala | 269 ++++++- .../spark/sql/execution/aggregate/utils.scala | 99 +-- .../spark/sql/execution/basicOperators.scala | 1 - .../org/apache/spark/sql/SQLQuerySuite.scala | 10 +- .../execution/SparkSqlSerializer2Suite.scala | 9 +- .../execution/AggregationQuerySuite.scala | 118 ++-- 13 files changed, 1697 insertions(+), 973 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index d08f553cef..4abfdfe87d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -110,7 +110,11 @@ abstract class AggregateFunction2 * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)` * will be 2. */ - var mutableBufferOffset: Int = 0 + protected var mutableBufferOffset: Int = 0 + + def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = { + mutableBufferOffset = newMutableBufferOffset + } /** * The offset of this function's start buffer value in the @@ -126,7 +130,11 @@ abstract class AggregateFunction2 * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` * will be 3 (position 0 is used for the value of `key`). */ - var inputBufferOffset: Int = 0 + protected var inputBufferOffset: Int = 0 + + def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = { + inputBufferOffset = newInputBufferOffset + } /** The schema of the aggregation buffer.
*/ def bufferSchema: StructType @@ -195,11 +203,8 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes) override def initialize(buffer: MutableRow): Unit = { - var i = 0 - while (i < bufferAttributes.size) { - buffer(i + mutableBufferOffset) = initialValues(i).eval() - i += 1 - } + throw new UnsupportedOperationException( + "AlgebraicAggregate's initialize should not be called directly") } override final def update(buffer: MutableRow, input: InternalRow): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala new file mode 100644 index 0000000000..cf568dc048 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} +import org.apache.spark.sql.types.StructType + +/** + * An Aggregate Operator used to evaluate [[AggregateFunction2]]. Based on the data types + * of the grouping expressions and aggregate functions, it determines whether to use + * sort-based aggregation or the hybrid approach (hash-based with sort-based as the + * fallback) to process input rows. + */ +case class Aggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + private[this] val allAggregateExpressions = + nonCompleteAggregateExpressions ++ completeAggregateExpressions + + private[this] val hasNonAlgebricAggregateFunctions = + !allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) + + // Use the hybrid iterator if (1) unsafe is enabled, (2) the schemas of the + // grouping key and the aggregation buffer are supported; and (3) all + // aggregate functions are algebraic.
+ private[this] val supportsHybridIterator: Boolean = { + val aggregationBufferSchema: StructType = + StructType.fromAttributes( + allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) + val groupKeySchema: StructType = + StructType.fromAttributes(groupingExpressions.map(_.toAttribute)) + + val schemaSupportsUnsafe: Boolean = + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeProjection.canSupport(groupKeySchema) + + // TODO: Use the hybrid iterator for non-algebraic aggregate functions. + sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && !hasNonAlgebricAggregateFunctions + } + + // We need to use sorted input if we have grouping expressions, and + // we cannot use the hybrid iterator or the hybrid iterator is disabled. + private[this] val requiresSortedInput: Boolean = { + groupingExpressions.nonEmpty && !supportsHybridIterator + } + + override def canProcessUnsafeRows: Boolean = !hasNonAlgebricAggregateFunctions + + // If result expressions' data types are all fixed length, we generate unsafe rows. + // (We have this requirement, instead of checking the result of UnsafeProjection.canSupport, + // because we use a mutable projection to generate the result.) + override def outputsUnsafeRows: Boolean = { + // resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength) + // TODO: Support generating UnsafeRows. We can just re-enable the line above and fix + // any issue we get. + false + } + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + if (requiresSortedInput) { + // TODO: We should not sort the input rows if they are just in reversed order. + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + } else { + Seq.fill(children.size)(Nil) + } + } + + override def outputOrdering: Seq[SortOrder] = { + if (requiresSortedInput) { + // It is possible that the child.outputOrdering starts with the required + // ordering expressions (e.g. we require [a] as the sort expression and the + // child's outputOrdering is [a, b]). We can only guarantee the output rows + // are sorted by values of groupingExpressions. + groupingExpressions.map(SortOrder(_, Ascending)) + } else { + Nil + } + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + // Because the constructor of an aggregation iterator will read at least the first row, + // we need to get the value of iter.hasNext first. + val hasInput = iter.hasNext + val useHybridIterator = + hasInput && + supportsHybridIterator && + groupingExpressions.nonEmpty + if (useHybridIterator) { + UnsafeHybridAggregationIterator.createFromInputIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection _, + child.output, + iter, + outputsUnsafeRows) + } else { + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator.
+ Iterator[InternalRow]() + } else { + val outputIter = SortBasedAggregationIterator.createFromInputIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection _, + newProjection _, + child.output, + iter, + outputsUnsafeRows) + if (!hasInput && groupingExpressions.isEmpty) { + // There is no input and there are no grouping expressions. + // We need to output a single row as the result. + Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + } else { + outputIter + } + } + } + } + + override def simpleString: String = { + val iterator = if (supportsHybridIterator && groupingExpressions.nonEmpty) { + classOf[UnsafeHybridAggregationIterator].getSimpleName + } else { + classOf[SortBasedAggregationIterator].getSimpleName + } + + s"""NewAggregate with $iterator ${groupingExpressions} ${allAggregateExpressions}""" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala new file mode 100644 index 0000000000..abca373b0c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -0,0 +1,490 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.unsafe.KVIterator + +import scala.collection.mutable.ArrayBuffer + +/** + * The base class of [[SortBasedAggregationIterator]] and [[UnsafeHybridAggregationIterator]]. + * It mainly contains two parts: + * 1. It initializes aggregate functions. + * 2. It creates two functions, `processRow` and `generateOutput`, based on the [[AggregateMode]] + * of its aggregate functions. `processRow` is the function to handle an input row. `generateOutput` + * is used to generate the result.
+ */ +abstract class AggregationIterator( + groupingKeyAttributes: Seq[Attribute], + valueAttributes: Seq[Attribute], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + outputsUnsafeRows: Boolean) + extends Iterator[InternalRow] with Logging { + + /////////////////////////////////////////////////////////////////////////// + // Initializing functions. + /////////////////////////////////////////////////////////////////////////// + + // A Seq of all AggregateExpressions. + // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final + // are at the beginning of the allAggregateExpressions. + protected val allAggregateExpressions = + nonCompleteAggregateExpressions ++ completeAggregateExpressions + + require( + allAggregateExpressions.map(_.mode).distinct.length <= 2, + s"$allAggregateExpressions are not supported because they have more than 2 distinct modes.") + + /** + * The distinct modes of AggregateExpressions. Right now, we can handle the following modes: + * - Partial-only: all AggregateExpressions have the mode of Partial; + * - PartialMerge-only: all AggregateExpressions have the mode of PartialMerge; + * - Final-only: all AggregateExpressions have the mode of Final; + * - Final-Complete: some AggregateExpressions have the mode of Final and + * others have the mode of Complete; + * - Complete-only: nonCompleteAggregateExpressions is empty and we have AggregateExpressions + * with mode Complete in completeAggregateExpressions; and + * - Grouping-only: there is no AggregateExpression. + */ + protected val aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = + nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> + completeAggregateExpressions.map(_.mode).distinct.headOption + + // Initialize all AggregateFunctions by binding references if necessary, + // and set inputBufferOffset and mutableBufferOffset. + protected val allAggregateFunctions: Array[AggregateFunction2] = { + var mutableBufferOffset = 0 + var inputBufferOffset: Int = initialInputBufferOffset + val functions = new Array[AggregateFunction2](allAggregateExpressions.length) + var i = 0 + while (i < allAggregateExpressions.length) { + val func = allAggregateExpressions(i).aggregateFunction + val funcWithBoundReferences = allAggregateExpressions(i).mode match { + case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] => + // We need to create BoundReferences if the function is not an + // AlgebraicAggregate (it does not support code-gen) and the mode of + // this function is Partial or Complete because we will call eval of this + // function's children in the update method of this aggregate function. + // Those eval calls require BoundReferences to work. + BindReferences.bindReference(func, valueAttributes) + case _ => + // We only need to set inputBufferOffset for aggregate functions with mode + // PartialMerge and Final. + func.withNewInputBufferOffset(inputBufferOffset) + inputBufferOffset += func.bufferSchema.length + func + } + // Set mutableBufferOffset for this function.
It is important that setting + // mutableBufferOffset happens after all potential bindReference operations + // because bindReference will create a new instance of the function. + funcWithBoundReferences.withNewMutableBufferOffset(mutableBufferOffset) + mutableBufferOffset += funcWithBoundReferences.bufferSchema.length + functions(i) = funcWithBoundReferences + i += 1 + } + functions + } + + // Positions of those non-algebraic aggregate functions in allAggregateFunctions. + // For example, we have func1, func2, func3, func4 in aggregateFunctions, and + // func2 and func3 are non-algebraic aggregate functions. + // nonAlgebraicAggregateFunctionPositions will be [1, 2]. + private[this] val allNonAlgebraicAggregateFunctionPositions: Array[Int] = { + val positions = new ArrayBuffer[Int]() + var i = 0 + while (i < allAggregateFunctions.length) { + allAggregateFunctions(i) match { + case agg: AlgebraicAggregate => // Skip algebraic aggregate functions. + case _ => positions += i + } + i += 1 + } + positions.toArray + } + + // All aggregate functions with mode Partial, PartialMerge, or Final. + private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction2] = + allAggregateFunctions.take(nonCompleteAggregateExpressions.length) + + // All non-algebraic aggregate functions with mode Partial, PartialMerge, or Final. + private[this] val nonCompleteNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = + nonCompleteAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + } + + // The projection used to initialize buffer values for all AlgebraicAggregates. + private[this] val algebraicInitialProjection = { + val initExpressions = allAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.initialValues + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + newMutableProjection(initExpressions, Nil)() + } + + // All non-algebraic AggregateFunctions. + private[this] val allNonAlgebraicAggregateFunctions = + allNonAlgebraicAggregateFunctionPositions.map(allAggregateFunctions) + + /////////////////////////////////////////////////////////////////////////// + // Methods and fields used by sub-classes. + /////////////////////////////////////////////////////////////////////////// + + // Initializing functions used to process a row. + protected val processRow: (MutableRow, InternalRow) => Unit = { + val rowToBeProcessed = new JoinedRow + val aggregationBufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + aggregationMode match { + // Partial-only + case (Some(Partial), None) => + val updateExpressions = nonCompleteAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + val algebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() + + (currentBuffer: MutableRow, row: InternalRow) => { + algebraicUpdateProjection.target(currentBuffer) + // Process all algebraic aggregate functions. + algebraicUpdateProjection(rowToBeProcessed(currentBuffer, row)) + // Process all non-algebraic aggregate functions.
+ var i = 0 + while (i < nonCompleteNonAlgebraicAggregateFunctions.length) { + nonCompleteNonAlgebraicAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } + } + + // PartialMerge-only or Final-only + case (Some(PartialMerge), None) | (Some(Final), None) => + val inputAggregationBufferSchema = if (initialInputBufferOffset == 0) { + // If initialInputBufferOffset is 0, the input values do not contain + // grouping keys. + // This part is pretty hacky. + allAggregateFunctions.flatMap(_.cloneBufferAttributes).toSeq + } else { + groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.cloneBufferAttributes) + } + val mergeExpressions = nonCompleteAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + // This projection is used to merge buffer values for all AlgebraicAggregates. + val algebraicMergeProjection = + newMutableProjection( + mergeExpressions, + aggregationBufferSchema ++ inputAggregationBufferSchema)() + + (currentBuffer: MutableRow, row: InternalRow) => { + // Process all algebraic aggregate functions. + algebraicMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row)) + // Process all non-algebraic aggregate functions. + var i = 0 + while (i < nonCompleteNonAlgebraicAggregateFunctions.length) { + nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row) + i += 1 + } + } + + // Final-Complete + case (Some(Final), Some(Complete)) => + val completeAggregateFunctions: Array[AggregateFunction2] = + allAggregateFunctions.takeRight(completeAggregateExpressions.length) + // All non-algebraic aggregate functions with mode Complete. + val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = + completeAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + } + + // The first initialInputBufferOffset values of the input aggregation buffer are + // for grouping expressions and distinct columns. + val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset) + + val completeOffsetExpressions = + Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + // We do not touch buffer values of aggregate functions with the Final mode.
+ val finalOffsetExpressions = + Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + + val mergeInputSchema = + aggregationBufferSchema ++ + groupingAttributesAndDistinctColumns ++ + nonCompleteAggregateFunctions.flatMap(_.cloneBufferAttributes) + val mergeExpressions = + nonCompleteAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } ++ completeOffsetExpressions + val finalAlgebraicMergeProjection = + newMutableProjection(mergeExpressions, mergeInputSchema)() + + val updateExpressions = + finalOffsetExpressions ++ completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + val completeAlgebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() + + (currentBuffer: MutableRow, row: InternalRow) => { + val input = rowToBeProcessed(currentBuffer, row) + // For all aggregate functions with mode Complete, update buffers. + completeAlgebraicUpdateProjection.target(currentBuffer)(input) + var i = 0 + while (i < completeNonAlgebraicAggregateFunctions.length) { + completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } + + // For all aggregate functions with mode Final, merge buffers. + finalAlgebraicMergeProjection.target(currentBuffer)(input) + i = 0 + while (i < nonCompleteNonAlgebraicAggregateFunctions.length) { + nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row) + i += 1 + } + } + + // Complete-only + case (None, Some(Complete)) => + val completeAggregateFunctions: Array[AggregateFunction2] = + allAggregateFunctions.takeRight(completeAggregateExpressions.length) + // All non-algebraic aggregate functions with mode Complete. + val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = + completeAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + } + + val updateExpressions = + completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + val completeAlgebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() + + (currentBuffer: MutableRow, row: InternalRow) => { + val input = rowToBeProcessed(currentBuffer, row) + // For all aggregate functions with mode Complete, update buffers. + completeAlgebraicUpdateProjection.target(currentBuffer)(input) + var i = 0 + while (i < completeNonAlgebraicAggregateFunctions.length) { + completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } + } + + // Grouping only. + case (None, None) => (currentBuffer: MutableRow, row: InternalRow) => {} + + case other => + sys.error( + s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + + s"support evaluate modes $other in this iterator.") + } + } + + // Initializing the function used to generate the output row. 
+ protected val generateOutput: (InternalRow, MutableRow) => InternalRow = { + val rowToBeEvaluated = new JoinedRow + val safeOutputRow = new GenericMutableRow(resultExpressions.length) + val mutableOutput = if (outputsUnsafeRows) { + UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow) + } else { + safeOutputRow + } + + aggregationMode match { + // Partial-only or PartialMerge-only: every output row is basically the values of + // the grouping expressions and the corresponding aggregation buffer. + case (Some(Partial), None) | (Some(PartialMerge), None) => + // Because we cannot copy a joinedRow containing an UnsafeRow (UnsafeRow does not + // support generic getter), we create a mutable projection to output the + // JoinedRow(currentGroupingKey, currentBuffer). + val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.bufferAttributes) + val resultProjection = + newMutableProjection( + groupingKeyAttributes ++ bufferSchema, + groupingKeyAttributes ++ bufferSchema)() + resultProjection.target(mutableOutput) + + (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { + resultProjection(rowToBeEvaluated(currentGroupingKey, currentBuffer)) + } + + // Final-only, Complete-only and Final-Complete: every output row contains values representing + // resultExpressions. + case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => + val bufferSchemata = + allAggregateFunctions.flatMap(_.bufferAttributes) + val evalExpressions = allAggregateFunctions.map { + case ae: AlgebraicAggregate => ae.evaluateExpression + case agg: AggregateFunction2 => NoOp + } + val algebraicEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() + val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes + // TODO: Use unsafe row. + val aggregateResult = new GenericMutableRow(aggregateResultSchema.length) + val resultProjection = + newMutableProjection( + resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)() + resultProjection.target(mutableOutput) + + (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { + // Generate results for all algebraic aggregate functions. + algebraicEvalProjection.target(aggregateResult)(currentBuffer) + // Generate results for all non-algebraic aggregate functions. + var i = 0 + while (i < allNonAlgebraicAggregateFunctions.length) { + aggregateResult.update( + allNonAlgebraicAggregateFunctionPositions(i), + allNonAlgebraicAggregateFunctions(i).eval(currentBuffer)) + i += 1 + } + resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult)) + } + + // Grouping-only: we only output values of grouping expressions. + case (None, None) => + val resultProjection = + newMutableProjection(resultExpressions, groupingKeyAttributes)() + resultProjection.target(mutableOutput) + + (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { + resultProjection(currentGroupingKey) + } + + case other => + sys.error( + s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + + s"support evaluate modes $other in this iterator.") + } + } + + /** Initializes buffer values for all aggregate functions.
*/ + protected def initializeBuffer(buffer: MutableRow): Unit = { + algebraicInitialProjection.target(buffer)(EmptyRow) + var i = 0 + while (i < allNonAlgebraicAggregateFunctions.length) { + allNonAlgebraicAggregateFunctions(i).initialize(buffer) + i += 1 + } + } + + /** + * Creates a new aggregation buffer and initializes buffer values + * for all aggregate functions. + */ + protected def newBuffer: MutableRow +} + +object AggregationIterator { + def kvIterator( + groupingExpressions: Seq[NamedExpression], + newProjection: (Seq[Expression], Seq[Attribute]) => Projection, + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = { + new KVIterator[InternalRow, InternalRow] { + private[this] val groupingKeyGenerator = newProjection(groupingExpressions, inputAttributes) + + private[this] var groupingKey: InternalRow = _ + + private[this] var value: InternalRow = _ + + override def next(): Boolean = { + if (inputIter.hasNext) { + // Read the next input row. + val inputRow = inputIter.next() + // Get groupingKey based on groupingExpressions. + groupingKey = groupingKeyGenerator(inputRow) + // The value is the inputRow. + value = inputRow + true + } else { + false + } + } + + override def getKey(): InternalRow = { + groupingKey + } + + override def getValue(): InternalRow = { + value + } + + override def close(): Unit = { + // Do nothing + } + } + } + + def unsafeKVIterator( + groupingExpressions: Seq[NamedExpression], + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = { + new KVIterator[UnsafeRow, InternalRow] { + private[this] val groupingKeyGenerator = + UnsafeProjection.create(groupingExpressions, inputAttributes) + + private[this] var groupingKey: UnsafeRow = _ + + private[this] var value: InternalRow = _ + + override def next(): Boolean = { + if (inputIter.hasNext) { + // Read the next input row. + val inputRow = inputIter.next() + // Get groupingKey based on groupingExpressions. + groupingKey = groupingKeyGenerator.apply(inputRow) + // The value is the inputRow. + value = inputRow + true + } else { + false + } + } + + override def getKey(): UnsafeRow = { + groupingKey + } + + override def getValue(): InternalRow = { + value + } + + override def close(): Unit = { + // Do nothing + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala new file mode 100644 index 0000000000..78bcee16c9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} +import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.KVIterator + +/** + * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been + * sorted by values of [[groupingKeyAttributes]]. + */ +class SortBasedAggregationIterator( + groupingKeyAttributes: Seq[Attribute], + valueAttributes: Seq[Attribute], + inputKVIterator: KVIterator[InternalRow, InternalRow], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + outputsUnsafeRows: Boolean) + extends AggregationIterator( + groupingKeyAttributes, + valueAttributes, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + outputsUnsafeRows) { + + override protected def newBuffer: MutableRow = { + val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferRowSize: Int = bufferSchema.length + + val genericMutableBuffer = new GenericMutableRow(bufferRowSize) + val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isFixedLength) + + val buffer = if (useUnsafeBuffer) { + val unsafeProjection = + UnsafeProjection.create(bufferSchema.map(_.dataType)) + unsafeProjection.apply(genericMutableBuffer) + } else { + genericMutableBuffer + } + initializeBuffer(buffer) + buffer + } + + /////////////////////////////////////////////////////////////////////////// + // Mutable states for sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // The grouping key of the current group. + private[this] var currentGroupingKey: InternalRow = _ + + // The grouping key of the next group. + private[this] var nextGroupingKey: InternalRow = _ + + // The first row of the next group. + private[this] var firstRowInNextGroup: InternalRow = _ + + // Indicates whether there is a new group of rows from the sorted input iterator. + private[this] var sortedInputHasNewGroup: Boolean = false + + // The aggregation buffer used by the sort-based aggregation. + private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer + + /** Processes rows in the current group. It will stop when it finds a new group. */ + protected def processCurrentSortedGroup(): Unit = { + currentGroupingKey = nextGroupingKey + // Now, we will start to find all rows belonging to this group. + // We create a variable to track if we see the next group. + var findNextPartition = false + // firstRowInNextGroup is the first row of this group. We first process it. + processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + + // The search will stop when we see the next group or there is no + // input row left in the iterator. + var hasNext = inputKVIterator.next() + while (!findNextPartition && hasNext) { + // Get the grouping key.
+ val groupingKey = inputKVIterator.getKey + val currentRow = inputKVIterator.getValue + + // Check if the current row belongs to the current group. + if (currentGroupingKey == groupingKey) { + processRow(sortBasedAggregationBuffer, currentRow) + + hasNext = inputKVIterator.next() + } else { + // We found a new group. + findNextPartition = true + nextGroupingKey = groupingKey.copy() + firstRowInNextGroup = currentRow.copy() + } + } + // We have not seen a new group. It means that there is no new row in the input + // iterator. The current group is the last group of the iterator. + if (!findNextPartition) { + sortedInputHasNewGroup = false + } + } + + /////////////////////////////////////////////////////////////////////////// + // Iterator's public methods + /////////////////////////////////////////////////////////////////////////// + + override final def hasNext: Boolean = sortedInputHasNewGroup + + override final def next(): InternalRow = { + if (hasNext) { + // Process the current group. + processCurrentSortedGroup() + // Generate output row for the current group. + val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer) + // Initialize buffer values for the next group. + initializeBuffer(sortBasedAggregationBuffer) + + outputRow + } else { + // no more result + throw new NoSuchElementException + } + } + + protected def initialize(): Unit = { + if (inputKVIterator.next()) { + initializeBuffer(sortBasedAggregationBuffer) + + nextGroupingKey = inputKVIterator.getKey().copy() + firstRowInNextGroup = inputKVIterator.getValue().copy() + + sortedInputHasNewGroup = true + } else { + // The input iterator is empty. + sortedInputHasNewGroup = false + } + } + + initialize() + + def outputForEmptyGroupingKeyWithoutInput(): InternalRow = { + initializeBuffer(sortBasedAggregationBuffer) + generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer) + } +} + +object SortBasedAggregationIterator { + // scalastyle:off + def createFromInputIterator( + groupingExprs: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + newProjection: (Seq[Expression], Seq[Attribute]) => Projection, + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow], + outputsUnsafeRows: Boolean): SortBasedAggregationIterator = { + val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) { + AggregationIterator.unsafeKVIterator( + groupingExprs, + inputAttributes, + inputIter).asInstanceOf[KVIterator[InternalRow, InternalRow]] + } else { + AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter) + } + + new SortBasedAggregationIterator( + groupingExprs.map(_.toAttribute), + inputAttributes, + kvIterator, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + outputsUnsafeRows) + } + + def createFromKVIterator( + groupingKeyAttributes: Seq[Attribute], + valueAttributes: Seq[Attribute], + inputKVIterator: KVIterator[InternalRow, InternalRow], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + outputsUnsafeRows: Boolean): SortBasedAggregationIterator = { + new SortBasedAggregationIterator( + groupingKeyAttributes, + valueAttributes, + inputKVIterator, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + outputsUnsafeRows) + } + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala new file mode 100644 index 0000000000..37d34eb7cc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.execution.{UnsafeKeyValueSorter, UnsafeFixedWidthAggregationMap} +import org.apache.spark.unsafe.KVIterator +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types.StructType + +/** + * An iterator used to evaluate [[AggregateFunction2]]. + * It first tries to use in-memory hash-based aggregation. If we cannot allocate more + * space for the hash map, we spill the sorted map entries, free the map, and then + * switch to sort-based aggregation. 
+ */ +class UnsafeHybridAggregationIterator( + groupingKeyAttributes: Seq[Attribute], + valueAttributes: Seq[Attribute], + inputKVIterator: KVIterator[UnsafeRow, InternalRow], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + outputsUnsafeRows: Boolean) + extends AggregationIterator( + groupingKeyAttributes, + valueAttributes, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + outputsUnsafeRows) { + + require(groupingKeyAttributes.nonEmpty) + + /////////////////////////////////////////////////////////////////////////// + // Unsafe Aggregation buffers + /////////////////////////////////////////////////////////////////////////// + + // This is the Unsafe Aggregation Map used to store all buffers. + private[this] val buffers = new UnsafeFixedWidthAggregationMap( + newBuffer, + StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)), + StructType.fromAttributes(groupingKeyAttributes), + TaskContext.get.taskMemoryManager(), + SparkEnv.get.shuffleMemoryManager, + 1024 * 16, // initial capacity + SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"), + false // disable tracking of performance metrics + ) + + override protected def newBuffer: UnsafeRow = { + val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferRowSize: Int = bufferSchema.length + + val genericMutableBuffer = new GenericMutableRow(bufferRowSize) + val unsafeProjection = + UnsafeProjection.create(bufferSchema.map(_.dataType)) + val buffer = unsafeProjection.apply(genericMutableBuffer) + initializeBuffer(buffer) + buffer + } + + /////////////////////////////////////////////////////////////////////////// + // Methods and variables related to switching to sort-based aggregation + /////////////////////////////////////////////////////////////////////////// + private[this] var sortBased = false + + private[this] var sortBasedAggregationIterator: SortBasedAggregationIterator = _ + + // The value part of the input KV iterator is used to store original input values of + // aggregate functions, so we need to convert them to aggregation buffers. + private def processOriginalInput( + firstKey: UnsafeRow, + firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { + new KVIterator[UnsafeRow, UnsafeRow] { + private[this] var isFirstRow = true + + private[this] var groupingKey: UnsafeRow = _ + + private[this] val buffer: UnsafeRow = newBuffer + + override def next(): Boolean = { + initializeBuffer(buffer) + if (isFirstRow) { + isFirstRow = false + groupingKey = firstKey + processRow(buffer, firstValue) + + true + } else if (inputKVIterator.next()) { + groupingKey = inputKVIterator.getKey() + val value = inputKVIterator.getValue() + processRow(buffer, value) + + true + } else { + false + } + } + + override def getKey(): UnsafeRow = { + groupingKey + } + + override def getValue(): UnsafeRow = { + buffer + } + + override def close(): Unit = { + // Do nothing. + } + } + } + + // The value of the input KV Iterator has the format of groupingExprs + aggregation buffer.
+ // We need to project the aggregation buffer out. + private def projectInputBufferToUnsafe( + firstKey: UnsafeRow, + firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { + new KVIterator[UnsafeRow, UnsafeRow] { + private[this] var isFirstRow = true + + private[this] var groupingKey: UnsafeRow = _ + + private[this] val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + + private[this] val value: UnsafeRow = { + val genericMutableRow = new GenericMutableRow(bufferSchema.length) + UnsafeProjection.create(bufferSchema.map(_.dataType)).apply(genericMutableRow) + } + + private[this] val projectInputBuffer = { + newMutableProjection(bufferSchema, valueAttributes)().target(value) + } + + override def next(): Boolean = { + if (isFirstRow) { + isFirstRow = false + groupingKey = firstKey + projectInputBuffer(firstValue) + + true + } else if (inputKVIterator.next()) { + groupingKey = inputKVIterator.getKey() + projectInputBuffer(inputKVIterator.getValue()) + + true + } else { + false + } + } + + override def getKey(): UnsafeRow = { + groupingKey + } + + override def getValue(): UnsafeRow = { + value + } + + override def close(): Unit = { + // Do nothing. + } + } + } + + /** + * We need to fall back to sort-based aggregation because we do not have enough memory + * for our in-memory hash map (i.e. `buffers`). + */ + private def switchToSortBasedAggregation( + currentGroupingKey: UnsafeRow, + currentRow: InternalRow): Unit = { + logInfo("falling back to sort-based aggregation.") + + // Step 1: Get the ExternalSorter containing entries of the map. + val externalSorter = buffers.destructAndCreateExternalSorter() + + // Step 2: Free the memory used by the map. + buffers.free() + + // Step 3: If we have aggregate functions with mode Partial or Complete, + // we need to process the input rows to get their aggregation buffers, + // so that the sort-based aggregation iterator can later merge them. + // If the aggregate functions have mode Final or PartialMerge, + // we just need to project the aggregation buffers from the input. + val needsProcess = aggregationMode match { + case (Some(Partial), None) => true + case (None, Some(Complete)) => true + case (Some(Final), Some(Complete)) => true + case _ => false + } + + val processedIterator = if (needsProcess) { + processOriginalInput(currentGroupingKey, currentRow) + } else { + // The input value's format is groupingExprs + buffer. + // We need to project the buffer part out. + projectInputBufferToUnsafe(currentGroupingKey, currentRow) + } + + // Step 4: Redirect processedIterator to externalSorter. + while (processedIterator.next()) { + externalSorter.insertKV(processedIterator.getKey(), processedIterator.getValue()) + } + + // Step 5: Get the sorted iterator from the externalSorter. + val sortedKVIterator: KVIterator[UnsafeRow, UnsafeRow] = externalSorter.sortedIterator() + + // Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator. + // For an aggregate function with mode Partial, its mode in the SortBasedAggregationIterator + // will be PartialMerge. For an aggregate function with mode Complete, + // its mode in the SortBasedAggregationIterator will be Final.
+ val newNonCompleteAggregateExpressions = allAggregateExpressions.map { + case AggregateExpression2(func, Partial, isDistinct) => + AggregateExpression2(func, PartialMerge, isDistinct) + case AggregateExpression2(func, Complete, isDistinct) => + AggregateExpression2(func, Final, isDistinct) + case other => other + } + val newNonCompleteAggregateAttributes = + nonCompleteAggregateAttributes ++ completeAggregateAttributes + + val newValueAttributes = + allAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + + sortBasedAggregationIterator = SortBasedAggregationIterator.createFromKVIterator( + groupingKeyAttributes = groupingKeyAttributes, + valueAttributes = newValueAttributes, + inputKVIterator = sortedKVIterator.asInstanceOf[KVIterator[InternalRow, InternalRow]], + nonCompleteAggregateExpressions = newNonCompleteAggregateExpressions, + nonCompleteAggregateAttributes = newNonCompleteAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = 0, + resultExpressions = resultExpressions, + newMutableProjection = newMutableProjection, + outputsUnsafeRows = outputsUnsafeRows) + } + + /////////////////////////////////////////////////////////////////////////// + // Methods used to initialize this iterator. + /////////////////////////////////////////////////////////////////////////// + + /** Starts to read input rows and falls back to sort-based aggregation if necessary. */ + protected def initialize(): Unit = { + var hasNext = inputKVIterator.next() + while (!sortBased && hasNext) { + val groupingKey = inputKVIterator.getKey() + val currentRow = inputKVIterator.getValue() + val buffer = buffers.getAggregationBuffer(groupingKey) + if (buffer == null) { + // buffer == null means that we could not allocate more memory. + // Now, we need to spill the map and switch to sort-based aggregation. + switchToSortBasedAggregation(groupingKey, currentRow) + sortBased = true + } else { + processRow(buffer, currentRow) + hasNext = inputKVIterator.next() + } + } + } + + // This is the starting point of this iterator. + initialize() + + // Creates the iterator for the hash aggregation map after we have populated the + // contents of that map. + private[this] val aggregationBufferMapIterator = buffers.iterator() + + private[this] var _mapIteratorHasNext = false + + // Pre-load the first key-value pair from the map to make hasNext idempotent. + if (!sortBased) { + _mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!_mapIteratorHasNext) { + buffers.free() + } + } + + /////////////////////////////////////////////////////////////////////////// + // Iterator's public methods + /////////////////////////////////////////////////////////////////////////// + + override final def hasNext: Boolean = { + (sortBased && sortBasedAggregationIterator.hasNext) || (!sortBased && _mapIteratorHasNext) + } + + + override final def next(): InternalRow = { + if (hasNext) { + if (sortBased) { + sortBasedAggregationIterator.next() + } else { + // We did not fall back to sort-based aggregation. + val result = + generateOutput( + aggregationBufferMapIterator.getKey, + aggregationBufferMapIterator.getValue) + // Pre-load next key-value pair from aggregationBufferMapIterator.
+ _mapIteratorHasNext = aggregationBufferMapIterator.next() + + if (!_mapIteratorHasNext) { + val resultCopy = result.copy() + buffers.free() + resultCopy + } else { + result + } + } + } else { + // no more result + throw new NoSuchElementException + } + } +} + +object UnsafeHybridAggregationIterator { + // scalastyle:off + def createFromInputIterator( + groupingExprs: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow], + outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = { + new UnsafeHybridAggregationIterator( + groupingExprs.map(_.toAttribute), + inputAttributes, + AggregationIterator.unsafeKVIterator(groupingExprs, inputAttributes, inputIter), + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + outputsUnsafeRows) + } + + def createFromKVIterator( + groupingKeyAttributes: Seq[Attribute], + valueAttributes: Seq[Attribute], + inputKVIterator: KVIterator[UnsafeRow, InternalRow], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = { + new UnsafeHybridAggregationIterator( + groupingKeyAttributes, + valueAttributes, + inputKVIterator, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + outputsUnsafeRows) + } + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala deleted file mode 100644 index 98538c462b..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} - -case class Aggregate2Sort( - requiredChildDistributionExpressions: Option[Seq[Expression]], - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateAttributes: Seq[Attribute], - resultExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - - override def canProcessUnsafeRows: Boolean = true - - override def references: AttributeSet = { - val referencesInResults = - AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes) - - AttributeSet( - groupingExpressions.flatMap(_.references) ++ - aggregateExpressions.flatMap(_.references) ++ - referencesInResults) - } - - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { - // TODO: We should not sort the input rows if they are just in reversed order. - groupingExpressions.map(SortOrder(_, Ascending)) :: Nil - } - - override def outputOrdering: Seq[SortOrder] = { - // It is possible that the child.outputOrdering starts with the required - // ordering expressions (e.g. we require [a] as the sort expression and the - // child's outputOrdering is [a, b]). We can only guarantee the output rows - // are sorted by values of groupingExpressions. 
- groupingExpressions.map(SortOrder(_, Ascending)) - } - - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - if (aggregateExpressions.length == 0) { - new FinalSortAggregationIterator( - groupingExpressions, - Nil, - Nil, - resultExpressions, - newMutableProjection, - child.output, - iter) - } else { - val aggregationIterator: SortAggregationIterator = { - aggregateExpressions.map(_.mode).distinct.toList match { - case Partial :: Nil => - new PartialSortAggregationIterator( - groupingExpressions, - aggregateExpressions, - newMutableProjection, - child.output, - iter) - case PartialMerge :: Nil => - new PartialMergeSortAggregationIterator( - groupingExpressions, - aggregateExpressions, - newMutableProjection, - child.output, - iter) - case Final :: Nil => - new FinalSortAggregationIterator( - groupingExpressions, - aggregateExpressions, - aggregateAttributes, - resultExpressions, - newMutableProjection, - child.output, - iter) - case other => - sys.error( - s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " + - s"modes $other in this operator.") - } - } - - aggregationIterator - } - } - } -} - -case class FinalAndCompleteAggregate2Sort( - previousGroupingExpressions: Seq[NamedExpression], - groupingExpressions: Seq[NamedExpression], - finalAggregateExpressions: Seq[AggregateExpression2], - finalAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - resultExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - override def references: AttributeSet = { - val referencesInResults = - AttributeSet(resultExpressions.flatMap(_.references)) -- - AttributeSet(finalAggregateExpressions) -- - AttributeSet(completeAggregateExpressions) - - AttributeSet( - groupingExpressions.flatMap(_.references) ++ - finalAggregateExpressions.flatMap(_.references) ++ - completeAggregateExpressions.flatMap(_.references) ++ - referencesInResults) - } - - override def requiredChildDistribution: List[Distribution] = { - if (groupingExpressions.isEmpty) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - groupingExpressions.map(SortOrder(_, Ascending)) :: Nil - - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - - new FinalAndCompleteSortAggregationIterator( - previousGroupingExpressions.length, - groupingExpressions, - finalAggregateExpressions, - finalAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - resultExpressions, - newMutableProjection, - child.output, - iter) - } - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala deleted file mode 100644 index 2ca0cb82c1..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala +++ /dev/null @@ -1,664 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.types.NullType - -import scala.collection.mutable.ArrayBuffer - -/** - * An iterator used to evaluate aggregate functions. It assumes that input rows - * are already grouped by values of `groupingExpressions`. - */ -private[sql] abstract class SortAggregationIterator( - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]) - extends Iterator[InternalRow] { - - /////////////////////////////////////////////////////////////////////////// - // Static fields for this iterator - /////////////////////////////////////////////////////////////////////////// - - protected val aggregateFunctions: Array[AggregateFunction2] = { - var mutableBufferOffset = 0 - var inputBufferOffset: Int = initialInputBufferOffset - val functions = new Array[AggregateFunction2](aggregateExpressions.length) - var i = 0 - while (i < aggregateExpressions.length) { - val func = aggregateExpressions(i).aggregateFunction - val funcWithBoundReferences = aggregateExpressions(i).mode match { - case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] => - // We need to create BoundReferences if the function is not an - // AlgebraicAggregate (it does not support code-gen) and the mode of - // this function is Partial or Complete because we will call eval of this - // function's children in the update method of this aggregate function. - // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, inputAttributes) - case _ => - // We only need to set inputBufferOffset for aggregate functions with mode - // PartialMerge and Final. - func.inputBufferOffset = inputBufferOffset - inputBufferOffset += func.bufferSchema.length - func - } - // Set mutableBufferOffset for this function. It is important that setting - // mutableBufferOffset happens after all potential bindReference operations - // because bindReference will create a new instance of the function. - funcWithBoundReferences.mutableBufferOffset = mutableBufferOffset - mutableBufferOffset += funcWithBoundReferences.bufferSchema.length - functions(i) = funcWithBoundReferences - i += 1 - } - functions - } - - // Positions of those non-algebraic aggregate functions in aggregateFunctions. - // For example, we have func1, func2, func3, func4 in aggregateFunctions, and - // func2 and func3 are non-algebraic aggregate functions. - // nonAlgebraicAggregateFunctionPositions will be [1, 2]. 
- protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = { - val positions = new ArrayBuffer[Int]() - var i = 0 - while (i < aggregateFunctions.length) { - aggregateFunctions(i) match { - case agg: AlgebraicAggregate => - case _ => positions += i - } - i += 1 - } - positions.toArray - } - - // All non-algebraic aggregate functions. - protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = - nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions) - - // This is used to project expressions for the grouping expressions. - protected val groupGenerator = - newMutableProjection(groupingExpressions, inputAttributes)() - - // The underlying buffer shared by all aggregate functions. - protected val buffer: MutableRow = { - // The number of elements of the underlying buffer of this operator. - // All aggregate functions are sharing this underlying buffer and they find their - // buffer values through bufferOffset. - // var size = 0 - // var i = 0 - // while (i < aggregateFunctions.length) { - // size += aggregateFunctions(i).bufferSchema.length - // i += 1 - // } - new GenericMutableRow(aggregateFunctions.map(_.bufferSchema.length).sum) - } - - protected val joinedRow = new JoinedRow - - // This projection is used to initialize buffer values for all AlgebraicAggregates. - protected val algebraicInitialProjection = { - val initExpressions = aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.initialValues - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - - newMutableProjection(initExpressions, Nil)().target(buffer) - } - - /////////////////////////////////////////////////////////////////////////// - // Mutable states - /////////////////////////////////////////////////////////////////////////// - - // The partition key of the current partition. - protected var currentGroupingKey: InternalRow = _ - // The partition key of next partition. - protected var nextGroupingKey: InternalRow = _ - // The first row of next partition. - protected var firstRowInNextGroup: InternalRow = _ - // Indicates if we has new group of rows to process. - protected var hasNewGroup: Boolean = true - - /** Initializes buffer values for all aggregate functions. */ - protected def initializeBuffer(): Unit = { - algebraicInitialProjection(EmptyRow) - var i = 0 - while (i < nonAlgebraicAggregateFunctions.length) { - nonAlgebraicAggregateFunctions(i).initialize(buffer) - i += 1 - } - } - - protected def initialize(): Unit = { - if (inputIter.hasNext) { - initializeBuffer() - val currentRow = inputIter.next().copy() - // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, - // we are making a copy at here. - nextGroupingKey = groupGenerator(currentRow).copy() - firstRowInNextGroup = currentRow - } else { - // This iter is an empty one. - hasNewGroup = false - } - } - - /////////////////////////////////////////////////////////////////////////// - // Private methods - /////////////////////////////////////////////////////////////////////////// - - /** Processes rows in the current group. It will stop when it find a new group. */ - private def processCurrentGroup(): Unit = { - currentGroupingKey = nextGroupingKey - // Now, we will start to find all rows belonging to this group. - // We create a variable to track if we see the next group. - var findNextPartition = false - // firstRowInNextGroup is the first row of this group. We first process it. 
- processRow(firstRowInNextGroup) - // The search will stop when we see the next group or there is no - // input row left in the iter. - while (inputIter.hasNext && !findNextPartition) { - val currentRow = inputIter.next() - // Get the grouping key based on the grouping expressions. - // For the below compare method, we do not need to make a copy of groupingKey. - val groupingKey = groupGenerator(currentRow) - // Check if the current row belongs the current input row. - if (currentGroupingKey == groupingKey) { - processRow(currentRow) - } else { - // We find a new group. - findNextPartition = true - nextGroupingKey = groupingKey.copy() - firstRowInNextGroup = currentRow.copy() - } - } - // We have not seen a new group. It means that there is no new row in the input - // iter. The current group is the last group of the iter. - if (!findNextPartition) { - hasNewGroup = false - } - } - - /////////////////////////////////////////////////////////////////////////// - // Public methods - /////////////////////////////////////////////////////////////////////////// - - override final def hasNext: Boolean = hasNewGroup - - override final def next(): InternalRow = { - if (hasNext) { - // Process the current group. - processCurrentGroup() - // Generate output row for the current group. - val outputRow = generateOutput() - // Initilize buffer values for the next group. - initializeBuffer() - - outputRow - } else { - // no more result - throw new NoSuchElementException - } - } - - /////////////////////////////////////////////////////////////////////////// - // Methods that need to be implemented - /////////////////////////////////////////////////////////////////////////// - - /** The initial input buffer offset for `inputBufferOffset` of an [[AggregateFunction2]]. */ - protected def initialInputBufferOffset: Int - - /** The function used to process an input row. */ - protected def processRow(row: InternalRow): Unit - - /** The function used to generate the result row. */ - protected def generateOutput(): InternalRow - - /////////////////////////////////////////////////////////////////////////// - // Initialize this iterator - /////////////////////////////////////////////////////////////////////////// - - initialize() -} - -/** - * An iterator used to do partial aggregations (for those aggregate functions with mode Partial). - * It assumes that input rows are already grouped by values of `groupingExpressions`. - * The format of its output rows is: - * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| - */ -class PartialSortAggregationIterator( - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]) - extends SortAggregationIterator( - groupingExpressions, - aggregateExpressions, - newMutableProjection, - inputAttributes, - inputIter) { - - // This projection is used to update buffer values for all AlgebraicAggregates. 
- private val algebraicUpdateProjection = { - val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes) - val updateExpressions = aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) - } - - override protected def initialInputBufferOffset: Int = 0 - - override protected def processRow(row: InternalRow): Unit = { - // Process all algebraic aggregate functions. - algebraicUpdateProjection(joinedRow(buffer, row)) - // Process all non-algebraic aggregate functions. - var i = 0 - while (i < nonAlgebraicAggregateFunctions.length) { - nonAlgebraicAggregateFunctions(i).update(buffer, row) - i += 1 - } - } - - override protected def generateOutput(): InternalRow = { - // We just output the grouping expressions and the underlying buffer. - joinedRow(currentGroupingKey, buffer).copy() - } -} - -/** - * An iterator used to do partial merge aggregations (for those aggregate functions with mode - * PartialMerge). It assumes that input rows are already grouped by values of - * `groupingExpressions`. - * The format of its input rows is: - * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| - * - * The format of its internal buffer is: - * |aggregationBuffer1|...|aggregationBufferN| - * - * The format of its output rows is: - * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| - */ -class PartialMergeSortAggregationIterator( - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]) - extends SortAggregationIterator( - groupingExpressions, - aggregateExpressions, - newMutableProjection, - inputAttributes, - inputIter) { - - // This projection is used to merge buffer values for all AlgebraicAggregates. - private val algebraicMergeProjection = { - val mergeInputSchema = - aggregateFunctions.flatMap(_.bufferAttributes) ++ - groupingExpressions.map(_.toAttribute) ++ - aggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - - newMutableProjection(mergeExpressions, mergeInputSchema)() - } - - override protected def initialInputBufferOffset: Int = groupingExpressions.length - - override protected def processRow(row: InternalRow): Unit = { - // Process all algebraic aggregate functions. - algebraicMergeProjection.target(buffer)(joinedRow(buffer, row)) - // Process all non-algebraic aggregate functions. - var i = 0 - while (i < nonAlgebraicAggregateFunctions.length) { - nonAlgebraicAggregateFunctions(i).merge(buffer, row) - i += 1 - } - } - - override protected def generateOutput(): InternalRow = { - // We output grouping expressions and aggregation buffers. - joinedRow(currentGroupingKey, buffer).copy() - } -} - -/** - * An iterator used to do final aggregations (for those aggregate functions with mode - * Final). It assumes that input rows are already grouped by values of - * `groupingExpressions`. 
- * The format of its input rows is: - * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| - * - * The format of its internal buffer is: - * |aggregationBuffer1|...|aggregationBufferN| - * - * The format of its output rows is represented by the schema of `resultExpressions`. - */ -class FinalSortAggregationIterator( - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateAttributes: Seq[Attribute], - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]) - extends SortAggregationIterator( - groupingExpressions, - aggregateExpressions, - newMutableProjection, - inputAttributes, - inputIter) { - - // The result of aggregate functions. - private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length) - - // The projection used to generate the output rows of this operator. - // This is only used when we are generating final results of aggregate functions. - private val resultProjection = - newMutableProjection( - resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)() - - // This projection is used to merge buffer values for all AlgebraicAggregates. - private val algebraicMergeProjection = { - val mergeInputSchema = - aggregateFunctions.flatMap(_.bufferAttributes) ++ - groupingExpressions.map(_.toAttribute) ++ - aggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - - newMutableProjection(mergeExpressions, mergeInputSchema)() - } - - // This projection is used to evaluate all AlgebraicAggregates. - private val algebraicEvalProjection = { - val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) - val evalExpressions = aggregateFunctions.map { - case ae: AlgebraicAggregate => ae.evaluateExpression - case agg: AggregateFunction2 => NoOp - } - - newMutableProjection(evalExpressions, bufferSchemata)() - } - - override protected def initialInputBufferOffset: Int = groupingExpressions.length - - override def initialize(): Unit = { - if (inputIter.hasNext) { - initializeBuffer() - val currentRow = inputIter.next().copy() - // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, - // we are making a copy at here. - nextGroupingKey = groupGenerator(currentRow).copy() - firstRowInNextGroup = currentRow - } else { - if (groupingExpressions.isEmpty) { - // If there is no grouping expression, we need to generate a single row as the output. - initializeBuffer() - // Right now, the buffer only contains initial buffer values. Because - // merging two buffers with initial values will generate a row that - // still store initial values. We set the currentRow as the copy of the current buffer. - // Because input aggregation buffer has initialInputBufferOffset extra values at the - // beginning, we create a dummy row for this part. - val currentRow = - joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy() - nextGroupingKey = groupGenerator(currentRow).copy() - firstRowInNextGroup = currentRow - } else { - // This iter is an empty one. - hasNewGroup = false - } - } - } - - override protected def processRow(row: InternalRow): Unit = { - // Process all algebraic aggregate functions. 
- algebraicMergeProjection.target(buffer)(joinedRow(buffer, row)) - // Process all non-algebraic aggregate functions. - var i = 0 - while (i < nonAlgebraicAggregateFunctions.length) { - nonAlgebraicAggregateFunctions(i).merge(buffer, row) - i += 1 - } - } - - override protected def generateOutput(): InternalRow = { - // Generate results for all algebraic aggregate functions. - algebraicEvalProjection.target(aggregateResult)(buffer) - // Generate results for all non-algebraic aggregate functions. - var i = 0 - while (i < nonAlgebraicAggregateFunctions.length) { - aggregateResult.update( - nonAlgebraicAggregateFunctionPositions(i), - nonAlgebraicAggregateFunctions(i).eval(buffer)) - i += 1 - } - resultProjection(joinedRow(currentGroupingKey, aggregateResult)) - } -} - -/** - * An iterator used to do both final aggregations (for those aggregate functions with mode - * Final) and complete aggregations (for those aggregate functions with mode Complete). - * It assumes that input rows are already grouped by values of `groupingExpressions`. - * The format of its input rows is: - * |groupingExpr1|...|groupingExprN|col1|...|colM|aggregationBuffer1|...|aggregationBufferN| - * col1 to colM are columns used by aggregate functions with Complete mode. - * aggregationBuffer1 to aggregationBufferN are buffers used by aggregate functions with - * Final mode. - * - * The format of its internal buffer is: - * |aggregationBuffer1|...|aggregationBuffer(N+M)| - * For aggregation buffers, first N aggregation buffers are used by N aggregate functions with - * mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode - * Complete. - * - * The format of its output rows is represented by the schema of `resultExpressions`. - */ -class FinalAndCompleteSortAggregationIterator( - override protected val initialInputBufferOffset: Int, - groupingExpressions: Seq[NamedExpression], - finalAggregateExpressions: Seq[AggregateExpression2], - finalAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]) - extends SortAggregationIterator( - groupingExpressions, - // TODO: document the ordering - finalAggregateExpressions ++ completeAggregateExpressions, - newMutableProjection, - inputAttributes, - inputIter) { - - // The result of aggregate functions. - private val aggregateResult: MutableRow = - new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length) - - // The projection used to generate the output rows of this operator. - // This is only used when we are generating final results of aggregate functions. - private val resultProjection = { - val inputSchema = - groupingExpressions.map(_.toAttribute) ++ - finalAggregateAttributes ++ - completeAggregateAttributes - newMutableProjection(resultExpressions, inputSchema)() - } - - // All aggregate functions with mode Final. - private val finalAggregateFunctions: Array[AggregateFunction2] = { - val functions = new Array[AggregateFunction2](finalAggregateExpressions.length) - var i = 0 - while (i < finalAggregateExpressions.length) { - functions(i) = aggregateFunctions(i) - i += 1 - } - functions - } - - // All non-algebraic aggregate functions with mode Final. 
- private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = - finalAggregateFunctions.collect { - case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - } - - // All aggregate functions with mode Complete. - private val completeAggregateFunctions: Array[AggregateFunction2] = { - val functions = new Array[AggregateFunction2](completeAggregateExpressions.length) - var i = 0 - while (i < completeAggregateExpressions.length) { - functions(i) = aggregateFunctions(finalAggregateFunctions.length + i) - i += 1 - } - functions - } - - // All non-algebraic aggregate functions with mode Complete. - private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = - completeAggregateFunctions.collect { - case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - } - - // This projection is used to merge buffer values for all AlgebraicAggregates with mode - // Final. - private val finalAlgebraicMergeProjection = { - // The first initialInputBufferOffset values of the input aggregation buffer is - // for grouping expressions and distinct columns. - val groupingAttributesAndDistinctColumns = inputAttributes.take(initialInputBufferOffset) - - val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) - - val mergeInputSchema = - finalAggregateFunctions.flatMap(_.bufferAttributes) ++ - completeAggregateFunctions.flatMap(_.bufferAttributes) ++ - groupingAttributesAndDistinctColumns ++ - finalAggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = - finalAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } ++ completeOffsetExpressions - newMutableProjection(mergeExpressions, mergeInputSchema)() - } - - // This projection is used to update buffer values for all AlgebraicAggregates with mode - // Complete. - private val completeAlgebraicUpdateProjection = { - // We do not touch buffer values of aggregate functions with the Final mode. - val finalOffsetExpressions = - Seq.fill(finalAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) - - val bufferSchema = - finalAggregateFunctions.flatMap(_.bufferAttributes) ++ - completeAggregateFunctions.flatMap(_.bufferAttributes) - val updateExpressions = - finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) - } - - // This projection is used to evaluate all AlgebraicAggregates. - private val algebraicEvalProjection = { - val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) - val evalExpressions = aggregateFunctions.map { - case ae: AlgebraicAggregate => ae.evaluateExpression - case agg: AggregateFunction2 => NoOp - } - - newMutableProjection(evalExpressions, bufferSchemata)() - } - - override def initialize(): Unit = { - if (inputIter.hasNext) { - initializeBuffer() - val currentRow = inputIter.next().copy() - // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, - // we are making a copy at here. - nextGroupingKey = groupGenerator(currentRow).copy() - firstRowInNextGroup = currentRow - } else { - if (groupingExpressions.isEmpty) { - // If there is no grouping expression, we need to generate a single row as the output. 
- initializeBuffer() - // Right now, the buffer only contains initial buffer values. Because - // merging two buffers with initial values will generate a row that - // still store initial values. We set the currentRow as the copy of the current buffer. - // Because input aggregation buffer has initialInputBufferOffset extra values at the - // beginning, we create a dummy row for this part. - val currentRow = - joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy() - nextGroupingKey = groupGenerator(currentRow).copy() - firstRowInNextGroup = currentRow - } else { - // This iter is an empty one. - hasNewGroup = false - } - } - } - - override protected def processRow(row: InternalRow): Unit = { - val input = joinedRow(buffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeAlgebraicUpdateProjection(input) - var i = 0 - while (i < completeNonAlgebraicAggregateFunctions.length) { - completeNonAlgebraicAggregateFunctions(i).update(buffer, row) - i += 1 - } - - // For all aggregate functions with mode Final, merge buffers. - finalAlgebraicMergeProjection.target(buffer)(input) - i = 0 - while (i < finalNonAlgebraicAggregateFunctions.length) { - finalNonAlgebraicAggregateFunctions(i).merge(buffer, row) - i += 1 - } - } - - override protected def generateOutput(): InternalRow = { - // Generate results for all algebraic aggregate functions. - algebraicEvalProjection.target(aggregateResult)(buffer) - // Generate results for all non-algebraic aggregate functions. - var i = 0 - while (i < nonAlgebraicAggregateFunctions.length) { - aggregateResult.update( - nonAlgebraicAggregateFunctionPositions(i), - nonAlgebraicAggregateFunctions(i).eval(buffer)) - i += 1 - } - - resultProjection(joinedRow(currentGroupingKey, aggregateResult)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index cc54319171..5fafc916bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -24,7 +24,154 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjecti import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} -import org.apache.spark.sql.types.{Metadata, StructField, StructType, DataType} +import org.apache.spark.sql.types._ + +/** + * A helper trait used to create specialized setter and getter for types supported by + * [[org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap]]'s buffer. + * (see UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema). 
+ */ +sealed trait BufferSetterGetterUtils { + + def createGetters(schema: StructType): Array[(InternalRow, Int) => Any] = { + val dataTypes = schema.fields.map(_.dataType) + val getters = new Array[(InternalRow, Int) => Any](dataTypes.length) + + var i = 0 + while (i < getters.length) { + getters(i) = dataTypes(i) match { + case BooleanType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal) + + case ByteType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getByte(ordinal) + + case ShortType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getShort(ordinal) + + case IntegerType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getInt(ordinal) + + case LongType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getLong(ordinal) + + case FloatType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getFloat(ordinal) + + case DoubleType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getDouble(ordinal) + + case dt: DecimalType => + val precision = dt.precision + val scale = dt.scale + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale) + + case other => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.get(ordinal, other) + } + + i += 1 + } + + getters + } + + def createSetters(schema: StructType): Array[((MutableRow, Int, Any) => Unit)] = { + val dataTypes = schema.fields.map(_.dataType) + val setters = new Array[(MutableRow, Int, Any) => Unit](dataTypes.length) + + var i = 0 + while (i < setters.length) { + setters(i) = dataTypes(i) match { + case b: BooleanType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setBoolean(ordinal, value.asInstanceOf[Boolean]) + } else { + row.setNullAt(ordinal) + } + + case ByteType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setByte(ordinal, value.asInstanceOf[Byte]) + } else { + row.setNullAt(ordinal) + } + + case ShortType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setShort(ordinal, value.asInstanceOf[Short]) + } else { + row.setNullAt(ordinal) + } + + case IntegerType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setInt(ordinal, value.asInstanceOf[Int]) + } else { + row.setNullAt(ordinal) + } + + case LongType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setLong(ordinal, value.asInstanceOf[Long]) + } else { + row.setNullAt(ordinal) + } + + case FloatType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setFloat(ordinal, value.asInstanceOf[Float]) + } else { + row.setNullAt(ordinal) + } + + case DoubleType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setDouble(ordinal, value.asInstanceOf[Double]) + } else { + row.setNullAt(ordinal) + } + + case dt: DecimalType => + val precision = dt.precision + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + } else { + row.setNullAt(ordinal) + } + + case other => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.update(ordinal, value) + } else { + row.setNullAt(ordinal) + } + } + + i += 1 + } + + setters + } +} /** 
 * A Mutable [[Row]] representing a mutable aggregation buffer.
@@ -35,7 +182,7 @@ private[sql] class MutableAggregationBufferImpl (
     toScalaConverters: Array[Any => Any],
     bufferOffset: Int,
     var underlyingBuffer: MutableRow)
-  extends MutableAggregationBuffer {
+  extends MutableAggregationBuffer with BufferSetterGetterUtils {
 
   private[this] val offsets: Array[Int] = {
     val newOffsets = new Array[Int](length)
@@ -47,6 +194,10 @@ private[sql] class MutableAggregationBufferImpl (
     newOffsets
   }
 
+  private[this] val bufferValueGetters = createGetters(schema)
+
+  private[this] val bufferValueSetters = createSetters(schema)
+
   override def length: Int = toCatalystConverters.length
 
   override def get(i: Int): Any = {
@@ -54,7 +205,7 @@ private[sql] class MutableAggregationBufferImpl (
       throw new IllegalArgumentException(
         s"Could not access ${i}th value in this buffer because it only has $length values.")
     }
-    toScalaConverters(i)(underlyingBuffer.get(offsets(i), schema(i).dataType))
+    toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i)))
   }
 
   def update(i: Int, value: Any): Unit = {
@@ -62,7 +213,15 @@ private[sql] class MutableAggregationBufferImpl (
       throw new IllegalArgumentException(
         s"Could not update ${i}th value in this buffer because it only has $length values.")
     }
-    underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
+
+    bufferValueSetters(i)(underlyingBuffer, offsets(i), toCatalystConverters(i)(value))
+  }
+
+  // Because the get method calls a specialized getter based on the schema, we cannot
+  // use the default implementation of isNullAt (which is get(i) == null).
+  // We have to override it to call isNullAt on the underlyingBuffer.
+  override def isNullAt(i: Int): Boolean = {
+    underlyingBuffer.isNullAt(offsets(i))
   }
 
   override def copy(): MutableAggregationBufferImpl = {
@@ -84,7 +243,7 @@ private[sql] class InputAggregationBuffer private[sql] (
     toScalaConverters: Array[Any => Any],
     bufferOffset: Int,
     var underlyingInputBuffer: InternalRow)
-  extends Row {
+  extends Row with BufferSetterGetterUtils {
 
   private[this] val offsets: Array[Int] = {
     val newOffsets = new Array[Int](length)
@@ -96,6 +255,10 @@ private[sql] class InputAggregationBuffer private[sql] (
     newOffsets
   }
 
+  private[this] val bufferValueGetters = createGetters(schema)
+
+  def getBufferOffset: Int = bufferOffset
+
   override def length: Int = toCatalystConverters.length
 
   override def get(i: Int): Any = {
@@ -103,8 +266,14 @@ private[sql] class InputAggregationBuffer private[sql] (
       throw new IllegalArgumentException(
         s"Could not access ${i}th value in this buffer because it only has $length values.")
     }
-    // TODO: Use buffer schema to avoid using generic getter.
-    toScalaConverters(i)(underlyingInputBuffer.get(offsets(i), schema(i).dataType))
+    toScalaConverters(i)(bufferValueGetters(i)(underlyingInputBuffer, offsets(i)))
+  }
+
+  // Because the get method calls a specialized getter based on the schema, we cannot
+  // use the default implementation of isNullAt (which is get(i) == null).
+  // We have to override it to call isNullAt on the underlyingInputBuffer.
+  override def isNullAt(i: Int): Boolean = {
+    underlyingInputBuffer.isNullAt(offsets(i))
   }
 
   override def copy(): InputAggregationBuffer = {
@@ -147,7 +316,7 @@ private[sql] case class ScalaUDAF(
 
   override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
 
-  val childrenSchema: StructType = {
+  private[this] val childrenSchema: StructType = {
     val inputFields = children.zipWithIndex.map {
       case (child, index) =>
         StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
@@ -155,7 +324,7 @@ private[sql] case class ScalaUDAF(
     StructType(inputFields)
   }
 
-  lazy val inputProjection = {
+  private lazy val inputProjection = {
     val inputAttributes = childrenSchema.toAttributes
     log.debug(
       s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
@@ -168,40 +337,68 @@ private[sql] case class ScalaUDAF(
     }
   }
 
-  val inputToScalaConverters: Any => Any =
+  private[this] val inputToScalaConverters: Any => Any =
     CatalystTypeConverters.createToScalaConverter(childrenSchema)
 
-  val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
-    CatalystTypeConverters.createToCatalystConverter(field.dataType)
+  private[this] val bufferValuesToCatalystConverters: Array[Any => Any] = {
+    bufferSchema.fields.map { field =>
+      CatalystTypeConverters.createToCatalystConverter(field.dataType)
+    }
   }
 
-  val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
-    CatalystTypeConverters.createToScalaConverter(field.dataType)
+  private[this] val bufferValuesToScalaConverters: Array[Any => Any] = {
+    bufferSchema.fields.map { field =>
+      CatalystTypeConverters.createToScalaConverter(field.dataType)
+    }
   }
 
-  lazy val inputAggregateBuffer: InputAggregationBuffer =
-    new InputAggregationBuffer(
-      bufferSchema,
-      bufferValuesToCatalystConverters,
-      bufferValuesToScalaConverters,
-      inputBufferOffset,
-      null)
-
-  lazy val mutableAggregateBuffer: MutableAggregationBufferImpl =
-    new MutableAggregationBufferImpl(
-      bufferSchema,
-      bufferValuesToCatalystConverters,
-      bufferValuesToScalaConverters,
-      mutableBufferOffset,
-      null)
+  // This buffer is only used on the executor side.
+  private[this] var inputAggregateBuffer: InputAggregationBuffer = null
+
+  // This buffer is only used on the executor side.
+  private[this] var mutableAggregateBuffer: MutableAggregationBufferImpl = null
+
+  // This buffer is only used on the executor side.
+  private[this] var evalAggregateBuffer: InputAggregationBuffer = null
+
+  /**
+   * Sets the inputBufferOffset to newInputBufferOffset and then creates a new instance of
+   * `inputAggregateBuffer` based on this new inputBufferOffset.
+   */
+  override def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = {
+    super.withNewInputBufferOffset(newInputBufferOffset)
+    // inputBufferOffset has been updated.
+    inputAggregateBuffer =
+      new InputAggregationBuffer(
+        bufferSchema,
+        bufferValuesToCatalystConverters,
+        bufferValuesToScalaConverters,
+        inputBufferOffset,
+        null)
+  }
 
-  lazy val evalAggregateBuffer: InputAggregationBuffer =
-    new InputAggregationBuffer(
-      bufferSchema,
-      bufferValuesToCatalystConverters,
-      bufferValuesToScalaConverters,
-      mutableBufferOffset,
-      null)
+  /**
+   * Sets the mutableBufferOffset to newMutableBufferOffset and then creates a new instance of
+   * `mutableAggregateBuffer` and `evalAggregateBuffer` based on this new mutableBufferOffset.
+ */ + override def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = { + super.withNewMutableBufferOffset(newMutableBufferOffset) + // mutableBufferOffset has been updated. + mutableAggregateBuffer = + new MutableAggregationBufferImpl( + bufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + mutableBufferOffset, + null) + evalAggregateBuffer = + new InputAggregationBuffer( + bufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + mutableBufferOffset, + null) + } override def initialize(buffer: MutableRow): Unit = { mutableAggregateBuffer.underlyingBuffer = buffer diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 03635baae4..960be08f84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -17,13 +17,9 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.types.{StructType, MapType, ArrayType} /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -52,13 +48,16 @@ object Utils { agg.aggregateFunction.bufferAttributes } val partialAggregate = - Aggregate2Sort( - None: Option[Seq[Expression]], - namedGroupingExpressions.map(_._2), - partialAggregateExpressions, - partialAggregateAttributes, - namedGroupingAttributes ++ partialAggregateAttributes, - child) + Aggregate( + requiredChildDistributionExpressions = None: Option[Seq[Expression]], + groupingExpressions = namedGroupingExpressions.map(_._2), + nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = 0, + resultExpressions = namedGroupingAttributes ++ partialAggregateAttributes, + child = child) // 2. Create an Aggregate Operator for final aggregations. 
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) @@ -78,13 +77,17 @@ object Utils { }.getOrElse(expression) }.asInstanceOf[NamedExpression] } - val finalAggregate = Aggregate2Sort( - Some(namedGroupingAttributes), - namedGroupingAttributes, - finalAggregateExpressions, - finalAggregateAttributes, - rewrittenResultExpressions, - partialAggregate) + val finalAggregate = + Aggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = namedGroupingAttributes.length, + resultExpressions = rewrittenResultExpressions, + child = partialAggregate) finalAggregate :: Nil } @@ -133,14 +136,21 @@ object Utils { val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => agg.aggregateFunction.bufferAttributes } + val partialAggregateGroupingExpressions = + (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2) + val partialAggregateResult = + namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes val partialAggregate = - Aggregate2Sort( - None: Option[Seq[Expression]], - (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2), - partialAggregateExpressions, - partialAggregateAttributes, - namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes, - child) + Aggregate( + requiredChildDistributionExpressions = None: Option[Seq[Expression]], + groupingExpressions = partialAggregateGroupingExpressions, + nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialAggregateResult, + child = child) // 2. Create an Aggregate Operator for partial merge aggregations. val partialMergeAggregateExpressions = functionsWithoutDistinct.map { @@ -151,14 +161,19 @@ object Utils { partialMergeAggregateExpressions.flatMap { agg => agg.aggregateFunction.bufferAttributes } + val partialMergeAggregateResult = + namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes val partialMergeAggregate = - Aggregate2Sort( - Some(namedGroupingAttributes), - namedGroupingAttributes ++ distinctColumnAttributes, - partialMergeAggregateExpressions, - partialMergeAggregateAttributes, - namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes, - partialAggregate) + Aggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, + nonCompleteAggregateExpressions = partialMergeAggregateExpressions, + nonCompleteAggregateAttributes = partialMergeAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = partialMergeAggregateResult, + child = partialAggregate) // 3. Create an Aggregate Operator for partial merge aggregations. 
val finalAggregateExpressions = functionsWithoutDistinct.map { @@ -199,15 +214,17 @@ object Utils { }.getOrElse(expression) }.asInstanceOf[NamedExpression] } - val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort( - namedGroupingAttributes ++ distinctColumnAttributes, - namedGroupingAttributes, - finalAggregateExpressions, - finalAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - rewrittenResultExpressions, - partialMergeAggregate) + val finalAndCompleteAggregate = + Aggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, + completeAggregateExpressions = completeAggregateExpressions, + completeAggregateAttributes = completeAggregateAttributes, + initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = rewrittenResultExpressions, + child = partialMergeAggregate) finalAndCompleteAggregate :: Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 2294a670c7..5a1b000e89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -220,7 +220,6 @@ case class TakeOrderedAndProject( override def outputOrdering: Seq[SortOrder] = sortOrder } - /** * :: DeveloperApi :: * Return a new RDD that has exactly `numPartitions` partitions. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 51fe9d9d98..bbadc202a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.scalatest.BeforeAndAfterAll - import java.sql.Timestamp +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException -import org.apache.spark.sql.execution.aggregate.Aggregate2Sort +import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ @@ -273,7 +273,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { var hasGeneratedAgg = false df.queryExecution.executedPlan.foreach { case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true - case newAggregate: Aggregate2Sort => hasGeneratedAgg = true + case newAggregate: aggregate.Aggregate => hasGeneratedAgg = true case _ => } if (!hasGeneratedAgg) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 54f82f89ed..7978ed57a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -138,7 +138,14 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with 
BeforeAndAfterAll s"Expected $expectedSerializerClass as the serializer of Exchange. " + s"However, the serializer was not set." val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage)) - assert(serializer.getClass === expectedSerializerClass) + val isExpectedSerializer = + serializer.getClass == expectedSerializerClass || + serializer.getClass == classOf[UnsafeRowSerializer] + val wrongSerializerErrorMessage = + s"Expected ${expectedSerializerClass.getCanonicalName} or " + + s"${classOf[UnsafeRowSerializer].getCanonicalName}. But " + + s"${serializer.getClass.getCanonicalName} is used." + assert(isExpectedSerializer, wrongSerializerErrorMessage) case _ => // Ignore other nodes. } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 0375eb79ad..6f0db27775 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.execution.aggregate.Aggregate2Sort +import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row} import org.scalatest.BeforeAndAfterAll import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} -class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { +abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { override val sqlContext = TestHive import sqlContext.implicits._ @@ -34,7 +34,7 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf override def beforeAll(): Unit = { originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 - sqlContext.sql("set spark.sql.useAggregate2=true") + sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") val data1 = Seq[(Integer, Integer)]( (1, 10), (null, -60), @@ -81,7 +81,7 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf sqlContext.sql("DROP TABLE IF EXISTS agg1") sqlContext.sql("DROP TABLE IF EXISTS agg2") sqlContext.dropTempTable("emptyTable") - sqlContext.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2") + sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) } test("empty table") { @@ -454,54 +454,86 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf } test("error handling") { - sqlContext.sql(s"set spark.sql.useAggregate2=false") - var errorMessage = intercept[AnalysisException] { - sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | mydoublesum(value), - | mydoubleavg(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + withSQLConf("spark.sql.useAggregate2" -> "false") { + val errorMessage = intercept[AnalysisException] { + sqlContext.sql( + """ + |SELECT + | key, + | sum(value + 1.5 * key), + | mydoublesum(value), + | mydoubleavg(value) + |FROM agg1 + |GROUP BY key + 
""".stripMargin).collect() + }.getMessage + assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + } // TODO: once we support Hive UDAF in the new interface, // we can remove the following two tests. - sqlContext.sql(s"set spark.sql.useAggregate2=true") - errorMessage = intercept[AnalysisException] { - sqlContext.sql( + withSQLConf("spark.sql.useAggregate2" -> "true") { + val errorMessage = intercept[AnalysisException] { + sqlContext.sql( + """ + |SELECT + | key, + | mydoublesum(value + 1.5 * key), + | stddev_samp(value) + |FROM agg1 + |GROUP BY key + """.stripMargin).collect() + }.getMessage + assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + + // This will fall back to the old aggregate + val newAggregateOperators = sqlContext.sql( """ |SELECT | key, - | mydoublesum(value + 1.5 * key), + | sum(value + 1.5 * key), | stddev_samp(value) |FROM agg1 |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) - - // This will fall back to the old aggregate - val newAggregateOperators = sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | stddev_samp(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).queryExecution.executedPlan.collect { - case agg: Aggregate2Sort => agg + """.stripMargin).queryExecution.executedPlan.collect { + case agg: aggregate.Aggregate => agg + } + val message = + "We should fallback to the old aggregation code path if " + + "there is any aggregate function that cannot be converted to the new interface." + assert(newAggregateOperators.isEmpty, message) } - val message = - "We should fallback to the old aggregation code path if there is any aggregate function " + - "that cannot be converted to the new interface." - assert(newAggregateOperators.isEmpty, message) + } +} + +class SortBasedAggregationQuerySuite extends AggregationQuerySuite { - sqlContext.sql(s"set spark.sql.useAggregate2=true") + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "false") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + } +} + +class TungstenAggregationQuerySuite extends AggregationQuerySuite { + + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) } } -- GitLab