diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 88fb516e64aaf379141273542013f882c98df3eb..a73024d6adba10328df7962dedacbc16287c0868 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -31,8 +31,11 @@ case class Average(child: Expression) extends AlgebraicAggregate { override def dataType: DataType = resultType // Expected input data type. - // TODO: Once we remove the old code path, we can use our analyzer to cast NullType - // to the default data type of the NumericType. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select avg(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) private val resultType = child.dataType match { @@ -256,12 +259,19 @@ case class Sum(child: Expression) extends AlgebraicAggregate { override def dataType: DataType = resultType // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select sum(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) private val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) + // TODO: Remove this line once we remove the NullType from inputTypes. + case NullType => IntegerType case _ => child.dataType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index a730ffbb217c01b458cdafadb2af9421f4b746bf..c5aaebe6732252b68e934fd6682df55b0fc40ed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -191,8 +191,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // aggregate function to the corresponding attribute of the function. val aggregateFunctionMap = aggregateExpressions.map { agg => val aggregateFunction = agg.aggregateFunction + val attribtue = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute (aggregateFunction, agg.isDistinct) -> - Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + (aggregateFunction -> attribtue) }.toMap val (functionsWithDistinct, functionsWithoutDistinct) = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 16498da080c88128b5339286e4780fc425f0f9a2..39f8f992a9f009c915d2941167ddf7c6365c2450 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream} +import java.io._ import java.nio.ByteBuffer import scala.reflect.ClassTag @@ -58,11 +58,26 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst */ override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) + // When `out` is backed by ChainedBufferOutputStream, we will get an + // UnsupportedOperationException when we call dOut.writeInt because it internally calls + // ChainedBufferOutputStream's write(b: Int), which is not supported. + // To workaround this issue, we create an array for sorting the int value. + // To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and + // run SparkSqlSerializer2SortMergeShuffleSuite. + private[this] var intBuffer: Array[Byte] = new Array[Byte](4) private[this] val dOut: DataOutputStream = new DataOutputStream(out) override def writeValue[T: ClassTag](value: T): SerializationStream = { val row = value.asInstanceOf[UnsafeRow] - dOut.writeInt(row.getSizeInBytes) + val size = row.getSizeInBytes + // This part is based on DataOutputStream's writeInt. + // It is for dOut.writeInt(row.getSizeInBytes). + intBuffer(0) = ((size >>> 24) & 0xFF).toByte + intBuffer(1) = ((size >>> 16) & 0xFF).toByte + intBuffer(2) = ((size >>> 8) & 0xFF).toByte + intBuffer(3) = ((size >>> 0) & 0xFF).toByte + dOut.write(intBuffer, 0, 4) + row.writeToStream(out, writeBuffer) this } @@ -90,6 +105,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def close(): Unit = { writeBuffer = null + intBuffer = null dOut.writeInt(EOF) dOut.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala deleted file mode 100644 index cf568dc048674f2c01ecf9c849df46c97fcb17c2..0000000000000000000000000000000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} -import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} -import org.apache.spark.sql.types.StructType - -/** - * An Aggregate Operator used to evaluate [[AggregateFunction2]]. Based on the data types - * of the grouping expressions and aggregate functions, it determines if it uses - * sort-based aggregation and hybrid (hash-based with sort-based as the fallback) to - * process input rows. - */ -case class Aggregate( - requiredChildDistributionExpressions: Option[Seq[Expression]], - groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - - private[this] val allAggregateExpressions = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - private[this] val hasNonAlgebricAggregateFunctions = - !allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) - - // Use the hybrid iterator if (1) unsafe is enabled, (2) the schemata of - // grouping key and aggregation buffer is supported; and (3) all - // aggregate functions are algebraic. - private[this] val supportsHybridIterator: Boolean = { - val aggregationBufferSchema: StructType = - StructType.fromAttributes( - allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) - val groupKeySchema: StructType = - StructType.fromAttributes(groupingExpressions.map(_.toAttribute)) - - val schemaSupportsUnsafe: Boolean = - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeProjection.canSupport(groupKeySchema) - - // TODO: Use the hybrid iterator for non-algebric aggregate functions. - sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && !hasNonAlgebricAggregateFunctions - } - - // We need to use sorted input if we have grouping expressions, and - // we cannot use the hybrid iterator or the hybrid is disabled. - private[this] val requiresSortedInput: Boolean = { - groupingExpressions.nonEmpty && !supportsHybridIterator - } - - override def canProcessUnsafeRows: Boolean = !hasNonAlgebricAggregateFunctions - - // If result expressions' data types are all fixed length, we generate unsafe rows - // (We have this requirement instead of check the result of UnsafeProjection.canSupport - // is because we use a mutable projection to generate the result). - override def outputsUnsafeRows: Boolean = { - // resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength) - // TODO: Supports generating UnsafeRows. We can just re-enable the line above and fix - // any issue we get. - false - } - - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { - if (requiresSortedInput) { - // TODO: We should not sort the input rows if they are just in reversed order. - groupingExpressions.map(SortOrder(_, Ascending)) :: Nil - } else { - Seq.fill(children.size)(Nil) - } - } - - override def outputOrdering: Seq[SortOrder] = { - if (requiresSortedInput) { - // It is possible that the child.outputOrdering starts with the required - // ordering expressions (e.g. we require [a] as the sort expression and the - // child's outputOrdering is [a, b]). We can only guarantee the output rows - // are sorted by values of groupingExpressions. - groupingExpressions.map(SortOrder(_, Ascending)) - } else { - Nil - } - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - // Because the constructor of an aggregation iterator will read at least the first row, - // we need to get the value of iter.hasNext first. - val hasInput = iter.hasNext - val useHybridIterator = - hasInput && - supportsHybridIterator && - groupingExpressions.nonEmpty - if (useHybridIterator) { - UnsafeHybridAggregationIterator.createFromInputIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection _, - child.output, - iter, - outputsUnsafeRows) - } else { - if (!hasInput && groupingExpressions.nonEmpty) { - // This is a grouped aggregate and the input iterator is empty, - // so return an empty iterator. - Iterator[InternalRow]() - } else { - val outputIter = SortBasedAggregationIterator.createFromInputIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection _ , - newProjection _, - child.output, - iter, - outputsUnsafeRows) - if (!hasInput && groupingExpressions.isEmpty) { - // There is no input and there is no grouping expressions. - // We need to output a single row as the output. - Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) - } else { - outputIter - } - } - } - } - } - - override def simpleString: String = { - val iterator = if (supportsHybridIterator && groupingExpressions.nonEmpty) { - classOf[UnsafeHybridAggregationIterator].getSimpleName - } else { - classOf[SortBasedAggregationIterator].getSimpleName - } - - s"""NewAggregate with $iterator ${groupingExpressions} ${allAggregateExpressions}""" - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala new file mode 100644 index 0000000000000000000000000000000000000000..ad428ad663f307634d83cce5ee44f360f424cf08 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} +import org.apache.spark.sql.types.StructType + +case class SortBasedAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def outputsUnsafeRows: Boolean = false + + override def canProcessUnsafeRows: Boolean = false + + override def canProcessSafeRows: Boolean = true + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + } + + override def outputOrdering: Seq[SortOrder] = { + groupingExpressions.map(SortOrder(_, Ascending)) + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + // Because the constructor of an aggregation iterator will read at least the first row, + // we need to get the value of iter.hasNext first. + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator[InternalRow]() + } else { + val outputIter = SortBasedAggregationIterator.createFromInputIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection _, + newProjection _, + child.output, + iter, + outputsUnsafeRows) + if (!hasInput && groupingExpressions.isEmpty) { + // There is no input and there is no grouping expressions. + // We need to output a single row as the output. + Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + } else { + outputIter + } + } + } + } + + override def simpleString: String = { + val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}""" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 40f6bff53d2b7de71bbfadae179d98423973529f..67ebafde25ad3e5620d1a6df8ea2d32b74a30b88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -204,31 +204,5 @@ object SortBasedAggregationIterator { newMutableProjection, outputsUnsafeRows) } - - def createFromKVIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[InternalRow, InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean): SortBasedAggregationIterator = { - new SortBasedAggregationIterator( - groupingKeyAttributes, - valueAttributes, - inputKVIterator, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) - } // scalastyle:on } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala new file mode 100644 index 0000000000000000000000000000000000000000..5a0b4d47d62f8dc10c16e3a495911581dea3dbe0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} + +case class TungstenAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression2], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def outputsUnsafeRows: Boolean = true + + override def canProcessUnsafeRows: Boolean = true + + override def canProcessSafeRows: Boolean = false + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + // This is for testing. We force TungstenAggregationIterator to fall back to sort-based + // aggregation once it has processed a given number of input rows. + private val testFallbackStartsAt: Option[Int] = { + sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match { + case null | "" => None + case fallbackStartsAt => Some(fallbackStartsAt.toInt) + } + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator.empty.asInstanceOf[Iterator[UnsafeRow]] + } else { + val aggregationIterator = + new TungstenAggregationIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + completeAggregateExpressions, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + child.output, + iter.asInstanceOf[Iterator[UnsafeRow]], + testFallbackStartsAt) + + if (!hasInput && groupingExpressions.isEmpty) { + Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) + } else { + aggregationIterator + } + } + } + } + + override def simpleString: String = { + val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + + testFallbackStartsAt match { + case None => s"TungstenAggregate ${groupingExpressions} ${allAggregateExpressions}" + case Some(fallbackStartsAt) => + s"TungstenAggregateWithControlledFallback ${groupingExpressions} " + + s"${allAggregateExpressions} fallbackStartsAt=$fallbackStartsAt" + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala new file mode 100644 index 0000000000000000000000000000000000000000..b9d44aace1009b01117e5cda473fe84258770e6e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -0,0 +1,667 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.unsafe.KVIterator +import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.types.StructType + +/** + * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. + * + * This iterator first uses hash-based aggregation to process input rows. It uses + * a hash map to store groups and their corresponding aggregation buffers. If we + * this map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]], + * it switches to sort-based aggregation. The process of the switch has the following step: + * - Step 1: Sort all entries of the hash map based on values of grouping expressions and + * spill them to disk. + * - Step 2: Create a external sorter based on the spilled sorted map entries. + * - Step 3: Redirect all input rows to the external sorter. + * - Step 4: Get a sorted [[KVIterator]] from the external sorter. + * - Step 5: Initialize sort-based aggregation. + * Then, this iterator works in the way of sort-based aggregation. + * + * The code of this class is organized as follows: + * - Part 1: Initializing aggregate functions. + * - Part 2: Methods and fields used by setting aggregation buffer values, + * processing input rows from inputIter, and generating output + * rows. + * - Part 3: Methods and fields used by hash-based aggregation. + * - Part 4: The function used to switch this iterator from hash-based + * aggregation to sort-based aggregation. + * - Part 5: Methods and fields used by sort-based aggregation. + * - Part 6: Loads input and process input rows. + * - Part 7: Public methods of this iterator. + * - Part 8: A utility function used to generate a result when there is no + * input and there is no grouping expression. + * + * @param groupingExpressions + * expressions for grouping keys + * @param nonCompleteAggregateExpressions + * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]], + * [[PartialMerge]], or [[Final]]. + * @param completeAggregateExpressions + * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]]. + * @param initialInputBufferOffset + * If this iterator is used to handle functions with mode [[PartialMerge]] or [[Final]]. + * The input rows have the format of `grouping keys + aggregation buffer`. + * This offset indicates the starting position of aggregation buffer in a input row. + * @param resultExpressions + * expressions for generating output rows. + * @param newMutableProjection + * the function used to create mutable projections. + * @param originalInputAttributes + * attributes of representing input rows from `inputIter`. + * @param inputIter + * the iterator containing input [[UnsafeRow]]s. + */ +class TungstenAggregationIterator( + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression2], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + originalInputAttributes: Seq[Attribute], + inputIter: Iterator[UnsafeRow], + testFallbackStartsAt: Option[Int]) + extends Iterator[UnsafeRow] with Logging { + + /////////////////////////////////////////////////////////////////////////// + // Part 1: Initializing aggregate functions. + /////////////////////////////////////////////////////////////////////////// + + // A Seq containing all AggregateExpressions. + // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final + // are at the beginning of the allAggregateExpressions. + private[this] val allAggregateExpressions: Seq[AggregateExpression2] = + nonCompleteAggregateExpressions ++ completeAggregateExpressions + + // Check to make sure we do not have more than three modes in our AggregateExpressions. + // If we have, users are hitting a bug and we throw an IllegalStateException. + if (allAggregateExpressions.map(_.mode).distinct.length > 2) { + throw new IllegalStateException( + s"$allAggregateExpressions should have no more than 2 kinds of modes.") + } + + // + // The modes of AggregateExpressions. Right now, we can handle the following mode: + // - Partial-only: + // All AggregateExpressions have the mode of Partial. + // For this case, aggregationMode is (Some(Partial), None). + // - PartialMerge-only: + // All AggregateExpressions have the mode of PartialMerge). + // For this case, aggregationMode is (Some(PartialMerge), None). + // - Final-only: + // All AggregateExpressions have the mode of Final. + // For this case, aggregationMode is (Some(Final), None). + // - Final-Complete: + // Some AggregateExpressions have the mode of Final and + // others have the mode of Complete. For this case, + // aggregationMode is (Some(Final), Some(Complete)). + // - Complete-only: + // nonCompleteAggregateExpressions is empty and we have AggregateExpressions + // with mode Complete in completeAggregateExpressions. For this case, + // aggregationMode is (None, Some(Complete)). + // - Grouping-only: + // There is no AggregateExpression. For this case, AggregationMode is (None,None). + // + private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = { + nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> + completeAggregateExpressions.map(_.mode).distinct.headOption + } + + // All aggregate functions. TungstenAggregationIterator only handles AlgebraicAggregates. + // If there is any functions that is not an AlgebraicAggregate, we throw an + // IllegalStateException. + private[this] val allAggregateFunctions: Array[AlgebraicAggregate] = { + if (!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])) { + throw new IllegalStateException( + "Only AlgebraicAggregates should be passed in TungstenAggregationIterator.") + } + + allAggregateExpressions + .map(_.aggregateFunction.asInstanceOf[AlgebraicAggregate]) + .toArray + } + + /////////////////////////////////////////////////////////////////////////// + // Part 2: Methods and fields used by setting aggregation buffer values, + // processing input rows from inputIter, and generating output + // rows. + /////////////////////////////////////////////////////////////////////////// + + // The projection used to initialize buffer values. + private[this] val algebraicInitialProjection: MutableProjection = { + val initExpressions = allAggregateFunctions.flatMap(_.initialValues) + newMutableProjection(initExpressions, Nil)() + } + + // Creates a new aggregation buffer and initializes buffer values. + // This functions should be only called at most three times (when we create the hash map, + // when we switch to sort-based aggregation, and when we create the re-used buffer for + // sort-based aggregation). + private def createNewAggregationBuffer(): UnsafeRow = { + val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferRowSize: Int = bufferSchema.length + + val genericMutableBuffer = new GenericMutableRow(bufferRowSize) + val unsafeProjection = + UnsafeProjection.create(bufferSchema.map(_.dataType)) + val buffer = unsafeProjection.apply(genericMutableBuffer) + algebraicInitialProjection.target(buffer)(EmptyRow) + buffer + } + + // Creates a function used to process a row based on the given inputAttributes. + private def generateProcessRow( + inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = { + + val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) + val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes) + val inputSchema = StructType.fromAttributes(inputAttributes) + val unsafeRowJoiner = + GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema) + + aggregationMode match { + // Partial-only + case (Some(Partial), None) => + val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions) + val algebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + algebraicUpdateProjection.target(currentBuffer) + algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + } + + // PartialMerge-only or Final-only + case (Some(PartialMerge), None) | (Some(Final), None) => + val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions) + // This projection is used to merge buffer values for all AlgebraicAggregates. + val algebraicMergeProjection = + newMutableProjection( + mergeExpressions, + aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + // Process all algebraic aggregate functions. + algebraicMergeProjection.target(currentBuffer) + algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row)) + } + + // Final-Complete + case (Some(Final), Some(Complete)) => + val nonCompleteAggregateFunctions: Array[AlgebraicAggregate] = + allAggregateFunctions.take(nonCompleteAggregateExpressions.length) + val completeAggregateFunctions: Array[AlgebraicAggregate] = + allAggregateFunctions.takeRight(completeAggregateExpressions.length) + + val completeOffsetExpressions = + Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + val mergeExpressions = + nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions + val finalAlgebraicMergeProjection = + newMutableProjection( + mergeExpressions, + aggregationBufferAttributes ++ inputAttributes)() + + // We do not touch buffer values of aggregate functions with the Final mode. + val finalOffsetExpressions = + Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + val updateExpressions = + finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions) + val completeAlgebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + val input = unsafeRowJoiner.join(currentBuffer, row) + // For all aggregate functions with mode Complete, update the given currentBuffer. + completeAlgebraicUpdateProjection.target(currentBuffer)(input) + + // For all aggregate functions with mode Final, merge buffer values in row to + // currentBuffer. + finalAlgebraicMergeProjection.target(currentBuffer)(input) + } + + // Complete-only + case (None, Some(Complete)) => + val completeAggregateFunctions: Array[AlgebraicAggregate] = + allAggregateFunctions.takeRight(completeAggregateExpressions.length) + + val updateExpressions = + completeAggregateFunctions.flatMap(_.updateExpressions) + val completeAlgebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + completeAlgebraicUpdateProjection.target(currentBuffer) + // For all aggregate functions with mode Complete, update the given currentBuffer. + completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + } + + // Grouping only. + case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {} + + case other => + throw new IllegalStateException( + s"${aggregationMode} should not be passed into TungstenAggregationIterator.") + } + } + + // Creates a function used to generate output rows. + private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = { + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + + aggregationMode match { + // Partial-only or PartialMerge-only: every output row is basically the values of + // the grouping expressions and the corresponding aggregation buffer. + case (Some(Partial), None) | (Some(PartialMerge), None) => + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { + unsafeRowJoiner.join(currentGroupingKey, currentBuffer) + } + + // Final-only, Complete-only and Final-Complete: a output row is generated based on + // resultExpressions. + case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { + resultProjection(unsafeRowJoiner.join(currentGroupingKey, currentBuffer)) + } + + // Grouping-only: a output row is generated from values of grouping expressions. + case (None, None) => + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { + resultProjection(currentGroupingKey) + } + + case other => + throw new IllegalStateException( + s"${aggregationMode} should not be passed into TungstenAggregationIterator.") + } + } + + // An UnsafeProjection used to extract grouping keys from the input rows. + private[this] val groupProjection = + UnsafeProjection.create(groupingExpressions, originalInputAttributes) + + // A function used to process a input row. Its first argument is the aggregation buffer + // and the second argument is the input row. + private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit = + generateProcessRow(originalInputAttributes) + + // A function used to generate output rows based on the grouping keys (first argument) + // and the corresponding aggregation buffer (second argument). + private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow = + generateResultProjection() + + // An aggregation buffer containing initial buffer values. It is used to + // initialize other aggregation buffers. + private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + + /////////////////////////////////////////////////////////////////////////// + // Part 3: Methods and fields used by hash-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // This is the hash map used for hash-based aggregation. It is backed by an + // UnsafeFixedWidthAggregationMap and it is used to store + // all groups and their corresponding aggregation buffers for hash-based aggregation. + private[this] val hashMap = new UnsafeFixedWidthAggregationMap( + initialAggregationBuffer, + StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)), + StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), + TaskContext.get.taskMemoryManager(), + SparkEnv.get.shuffleMemoryManager, + 1024 * 16, // initial capacity + SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"), + false // disable tracking of performance metrics + ) + + // The function used to read and process input rows. When processing input rows, + // it first uses hash-based aggregation by putting groups and their buffers in + // hashMap. If we could not allocate more memory for the map, we switch to + // sort-based aggregation (by calling switchToSortBasedAggregation). + private def processInputs(): Unit = { + while (!sortBased && inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey) + if (buffer == null) { + // buffer == null means that we could not allocate more memory. + // Now, we need to spill the map and switch to sort-based aggregation. + switchToSortBasedAggregation(groupingKey, newInput) + } else { + processRow(buffer, newInput) + } + } + } + + // This function is only used for testing. It basically the same as processInputs except + // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have + // been processed. + private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = { + var i = 0 + while (!sortBased && inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + val buffer: UnsafeRow = if (i < fallbackStartsAt) { + hashMap.getAggregationBuffer(groupingKey) + } else { + null + } + if (buffer == null) { + // buffer == null means that we could not allocate more memory. + // Now, we need to spill the map and switch to sort-based aggregation. + switchToSortBasedAggregation(groupingKey, newInput) + } else { + processRow(buffer, newInput) + } + i += 1 + } + } + + // The iterator created from hashMap. It is used to generate output rows when we + // are using hash-based aggregation. + private[this] var aggregationBufferMapIterator: KVIterator[UnsafeRow, UnsafeRow] = null + + // Indicates if aggregationBufferMapIterator still has key-value pairs. + private[this] var mapIteratorHasNext: Boolean = false + + /////////////////////////////////////////////////////////////////////////// + // Part 4: The function used to switch this iterator from hash-based + // aggregation to sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = { + logInfo("falling back to sort based aggregation.") + // Step 1: Get the ExternalSorter containing sorted entries of the map. + val externalSorter: UnsafeKVExternalSorter = hashMap.destructAndCreateExternalSorter() + + // Step 2: Free the memory used by the map. + hashMap.free() + + // Step 3: If we have aggregate function with mode Partial or Complete, + // we need to process input rows to get aggregation buffer. + // So, later in the sort-based aggregation iterator, we can do merge. + // If aggregate functions are with mode Final and PartialMerge, + // we just need to project the aggregation buffer from an input row. + val needsProcess = aggregationMode match { + case (Some(Partial), None) => true + case (None, Some(Complete)) => true + case (Some(Final), Some(Complete)) => true + case _ => false + } + + if (needsProcess) { + // First, we create a buffer. + val buffer = createNewAggregationBuffer() + + // Process firstKey and firstInput. + // Initialize buffer. + buffer.copyFrom(initialAggregationBuffer) + processRow(buffer, firstInput) + externalSorter.insertKV(firstKey, buffer) + + // Process the rest of input rows. + while (inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + buffer.copyFrom(initialAggregationBuffer) + processRow(buffer, newInput) + externalSorter.insertKV(groupingKey, buffer) + } + } else { + // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer. + // We need to project the aggregation buffer part from an input row. + val buffer = createNewAggregationBuffer() + // The originalInputAttributes are using cloneBufferAttributes. So, we need to use + // allAggregateFunctions.flatMap(_.cloneBufferAttributes). + val bufferExtractor = newMutableProjection( + allAggregateFunctions.flatMap(_.cloneBufferAttributes), + originalInputAttributes)() + bufferExtractor.target(buffer) + + // Insert firstKey and its buffer. + bufferExtractor(firstInput) + externalSorter.insertKV(firstKey, buffer) + + // Insert the rest of input rows. + while (inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + bufferExtractor(newInput) + externalSorter.insertKV(groupingKey, buffer) + } + } + + // Set aggregationMode, processRow, and generateOutput for sort-based aggregation. + val newAggregationMode = aggregationMode match { + case (Some(Partial), None) => (Some(PartialMerge), None) + case (None, Some(Complete)) => (Some(Final), None) + case (Some(Final), Some(Complete)) => (Some(Final), None) + case other => other + } + aggregationMode = newAggregationMode + + // Basically the value of the KVIterator returned by externalSorter + // will just aggregation buffer. At here, we use cloneBufferAttributes. + val newInputAttributes: Seq[Attribute] = + allAggregateFunctions.flatMap(_.cloneBufferAttributes) + + // Set up new processRow and generateOutput. + processRow = generateProcessRow(newInputAttributes) + generateOutput = generateResultProjection() + + // Step 5: Get the sorted iterator from the externalSorter. + sortedKVIterator = externalSorter.sortedIterator() + + // Step 6: Pre-load the first key-value pair from the sorted iterator to make + // hasNext idempotent. + sortedInputHasNewGroup = sortedKVIterator.next() + + // Copy the first key and value (aggregation buffer). + if (sortedInputHasNewGroup) { + val key = sortedKVIterator.getKey + val value = sortedKVIterator.getValue + nextGroupingKey = key.copy() + currentGroupingKey = key.copy() + firstRowInNextGroup = value.copy() + } + + // Step 7: set sortBased to true. + sortBased = true + } + + /////////////////////////////////////////////////////////////////////////// + // Part 5: Methods and fields used by sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // Indicates if we are using sort-based aggregation. Because we first try to use + // hash-based aggregation, its initial value is false. + private[this] var sortBased: Boolean = false + + // The KVIterator containing input rows for the sort-based aggregation. It will be + // set in switchToSortBasedAggregation when we switch to sort-based aggregation. + private[this] var sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = null + + // The grouping key of the current group. + private[this] var currentGroupingKey: UnsafeRow = null + + // The grouping key of next group. + private[this] var nextGroupingKey: UnsafeRow = null + + // The first row of next group. + private[this] var firstRowInNextGroup: UnsafeRow = null + + // Indicates if we has new group of rows from the sorted input iterator. + private[this] var sortedInputHasNewGroup: Boolean = false + + // The aggregation buffer used by the sort-based aggregation. + private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + + // Processes rows in the current group. It will stop when it find a new group. + private def processCurrentSortedGroup(): Unit = { + // First, we need to copy nextGroupingKey to currentGroupingKey. + currentGroupingKey.copyFrom(nextGroupingKey) + // Now, we will start to find all rows belonging to this group. + // We create a variable to track if we see the next group. + var findNextPartition = false + // firstRowInNextGroup is the first row of this group. We first process it. + processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + + // The search will stop when we see the next group or there is no + // input row left in the iter. + // Pre-load the first key-value pair to make the condition of the while loop + // has no action (we do not trigger loading a new key-value pair + // when we evaluate the condition). + var hasNext = sortedKVIterator.next() + while (!findNextPartition && hasNext) { + // Get the grouping key and value (aggregation buffer). + val groupingKey = sortedKVIterator.getKey + val inputAggregationBuffer = sortedKVIterator.getValue + + // Check if the current row belongs the current input row. + if (currentGroupingKey.equals(groupingKey)) { + processRow(sortBasedAggregationBuffer, inputAggregationBuffer) + + hasNext = sortedKVIterator.next() + } else { + // We find a new group. + findNextPartition = true + // copyFrom will fail when + nextGroupingKey.copyFrom(groupingKey) // = groupingKey.copy() + firstRowInNextGroup.copyFrom(inputAggregationBuffer) // = inputAggregationBuffer.copy() + + } + } + // We have not seen a new group. It means that there is no new row in the input + // iter. The current group is the last group of the sortedKVIterator. + if (!findNextPartition) { + sortedInputHasNewGroup = false + sortedKVIterator.close() + } + } + + /////////////////////////////////////////////////////////////////////////// + // Part 6: Loads input rows and setup aggregationBufferMapIterator if we + // have not switched to sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // Starts to process input rows. + testFallbackStartsAt match { + case None => + processInputs() + case Some(fallbackStartsAt) => + // This is the testing path. processInputsWithControlledFallback is same as processInputs + // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows + // have been processed. + processInputsWithControlledFallback(fallbackStartsAt) + } + + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + if (!sortBased) { + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. + mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!mapIteratorHasNext) { + hashMap.free() + } + } + + /////////////////////////////////////////////////////////////////////////// + // Par 7: Iterator's public methods. + /////////////////////////////////////////////////////////////////////////// + + override final def hasNext: Boolean = { + (sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext) + } + + override final def next(): UnsafeRow = { + if (hasNext) { + if (sortBased) { + // Process the current group. + processCurrentSortedGroup() + // Generate output row for the current group. + val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer) + // Initialize buffer values for the next group. + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + + outputRow + } else { + // We did not fall back to sort-based aggregation. + val result = + generateOutput( + aggregationBufferMapIterator.getKey, + aggregationBufferMapIterator.getValue) + + // Pre-load next key-value pair form aggregationBufferMapIterator to make hasNext + // idempotent. + mapIteratorHasNext = aggregationBufferMapIterator.next() + + if (!mapIteratorHasNext) { + // If there is no input from aggregationBufferMapIterator, we copy current result. + val resultCopy = result.copy() + // Then, we free the map. + hashMap.free() + + resultCopy + } else { + result + } + } + } else { + // no more result + throw new NoSuchElementException + } + } + + /////////////////////////////////////////////////////////////////////////// + // Part 8: A utility function used to generate a output row when there is no + // input and there is no grouping expression. + /////////////////////////////////////////////////////////////////////////// + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { + if (groupingExpressions.isEmpty) { + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + // We create a output row and copy it. So, we can free the map. + val resultCopy = + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() + hashMap.free() + resultCopy + } else { + throw new IllegalStateException( + "This method should not be called when groupingExpressions is not empty.") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala deleted file mode 100644 index b465787fe8cbd38a73bc1c618cdf04a73557f2be..0000000000000000000000000000000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} -import org.apache.spark.sql.types.StructType - -/** - * An iterator used to evaluate [[AggregateFunction2]]. - * It first tries to use in-memory hash-based aggregation. If we cannot allocate more - * space for the hash map, we spill the sorted map entries, free the map, and then - * switch to sort-based aggregation. - */ -class UnsafeHybridAggregationIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[UnsafeRow, InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) - extends AggregationIterator( - groupingKeyAttributes, - valueAttributes, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) { - - require(groupingKeyAttributes.nonEmpty) - - /////////////////////////////////////////////////////////////////////////// - // Unsafe Aggregation buffers - /////////////////////////////////////////////////////////////////////////// - - // This is the Unsafe Aggregation Map used to store all buffers. - private[this] val buffers = new UnsafeFixedWidthAggregationMap( - newBuffer, - StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)), - StructType.fromAttributes(groupingKeyAttributes), - TaskContext.get.taskMemoryManager(), - SparkEnv.get.shuffleMemoryManager, - 1024 * 16, // initial capacity - SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"), - false // disable tracking of performance metrics - ) - - override protected def newBuffer: UnsafeRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) - val bufferRowSize: Int = bufferSchema.length - - val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val unsafeProjection = - UnsafeProjection.create(bufferSchema.map(_.dataType)) - val buffer = unsafeProjection.apply(genericMutableBuffer) - initializeBuffer(buffer) - buffer - } - - /////////////////////////////////////////////////////////////////////////// - // Methods and variables related to switching to sort-based aggregation - /////////////////////////////////////////////////////////////////////////// - private[this] var sortBased = false - - private[this] var sortBasedAggregationIterator: SortBasedAggregationIterator = _ - - // The value part of the input KV iterator is used to store original input values of - // aggregate functions, we need to convert them to aggregation buffers. - private def processOriginalInput( - firstKey: UnsafeRow, - firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { - new KVIterator[UnsafeRow, UnsafeRow] { - private[this] var isFirstRow = true - - private[this] var groupingKey: UnsafeRow = _ - - private[this] val buffer: UnsafeRow = newBuffer - - override def next(): Boolean = { - initializeBuffer(buffer) - if (isFirstRow) { - isFirstRow = false - groupingKey = firstKey - processRow(buffer, firstValue) - - true - } else if (inputKVIterator.next()) { - groupingKey = inputKVIterator.getKey() - val value = inputKVIterator.getValue() - processRow(buffer, value) - - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { - groupingKey - } - - override def getValue(): UnsafeRow = { - buffer - } - - override def close(): Unit = { - // Do nothing. - } - } - } - - // The value of the input KV Iterator has the format of groupingExprs + aggregation buffer. - // We need to project the aggregation buffer out. - private def projectInputBufferToUnsafe( - firstKey: UnsafeRow, - firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { - new KVIterator[UnsafeRow, UnsafeRow] { - private[this] var isFirstRow = true - - private[this] var groupingKey: UnsafeRow = _ - - private[this] val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) - - private[this] val value: UnsafeRow = { - val genericMutableRow = new GenericMutableRow(bufferSchema.length) - UnsafeProjection.create(bufferSchema.map(_.dataType)).apply(genericMutableRow) - } - - private[this] val projectInputBuffer = { - newMutableProjection(bufferSchema, valueAttributes)().target(value) - } - - override def next(): Boolean = { - if (isFirstRow) { - isFirstRow = false - groupingKey = firstKey - projectInputBuffer(firstValue) - - true - } else if (inputKVIterator.next()) { - groupingKey = inputKVIterator.getKey() - projectInputBuffer(inputKVIterator.getValue()) - - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { - groupingKey - } - - override def getValue(): UnsafeRow = { - value - } - - override def close(): Unit = { - // Do nothing. - } - } - } - - /** - * We need to fall back to sort based aggregation because we do not have enough memory - * for our in-memory hash map (i.e. `buffers`). - */ - private def switchToSortBasedAggregation( - currentGroupingKey: UnsafeRow, - currentRow: InternalRow): Unit = { - logInfo("falling back to sort based aggregation.") - - // Step 1: Get the ExternalSorter containing entries of the map. - val externalSorter = buffers.destructAndCreateExternalSorter() - - // Step 2: Free the memory used by the map. - buffers.free() - - // Step 3: If we have aggregate function with mode Partial or Complete, - // we need to process them to get aggregation buffer. - // So, later in the sort-based aggregation iterator, we can do merge. - // If aggregate functions are with mode Final and PartialMerge, - // we just need to project the aggregation buffer from the input. - val needsProcess = aggregationMode match { - case (Some(Partial), None) => true - case (None, Some(Complete)) => true - case (Some(Final), Some(Complete)) => true - case _ => false - } - - val processedIterator = if (needsProcess) { - processOriginalInput(currentGroupingKey, currentRow) - } else { - // The input value's format is groupingExprs + buffer. - // We need to project the buffer part out. - projectInputBufferToUnsafe(currentGroupingKey, currentRow) - } - - // Step 4: Redirect processedIterator to externalSorter. - while (processedIterator.next()) { - externalSorter.insertKV(processedIterator.getKey(), processedIterator.getValue()) - } - - // Step 5: Get the sorted iterator from the externalSorter. - val sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = externalSorter.sortedIterator() - - // Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator. - // For a aggregate function with mode Partial, its mode in the SortBasedAggregationIterator - // will be PartialMerge. For a aggregate function with mode Complete, - // its mode in the SortBasedAggregationIterator will be Final. - val newNonCompleteAggregateExpressions = allAggregateExpressions.map { - case AggregateExpression2(func, Partial, isDistinct) => - AggregateExpression2(func, PartialMerge, isDistinct) - case AggregateExpression2(func, Complete, isDistinct) => - AggregateExpression2(func, Final, isDistinct) - case other => other - } - val newNonCompleteAggregateAttributes = - nonCompleteAggregateAttributes ++ completeAggregateAttributes - - val newValueAttributes = - allAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) - - sortBasedAggregationIterator = SortBasedAggregationIterator.createFromKVIterator( - groupingKeyAttributes = groupingKeyAttributes, - valueAttributes = newValueAttributes, - inputKVIterator = sortedKVIterator.asInstanceOf[KVIterator[InternalRow, InternalRow]], - nonCompleteAggregateExpressions = newNonCompleteAggregateExpressions, - nonCompleteAggregateAttributes = newNonCompleteAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = resultExpressions, - newMutableProjection = newMutableProjection, - outputsUnsafeRows = outputsUnsafeRows) - } - - /////////////////////////////////////////////////////////////////////////// - // Methods used to initialize this iterator. - /////////////////////////////////////////////////////////////////////////// - - /** Starts to read input rows and falls back to sort-based aggregation if necessary. */ - protected def initialize(): Unit = { - var hasNext = inputKVIterator.next() - while (!sortBased && hasNext) { - val groupingKey = inputKVIterator.getKey() - val currentRow = inputKVIterator.getValue() - val buffer = buffers.getAggregationBuffer(groupingKey) - if (buffer == null) { - // buffer == null means that we could not allocate more memory. - // Now, we need to spill the map and switch to sort-based aggregation. - switchToSortBasedAggregation(groupingKey, currentRow) - sortBased = true - } else { - processRow(buffer, currentRow) - hasNext = inputKVIterator.next() - } - } - } - - // This is the starting point of this iterator. - initialize() - - // Creates the iterator for the Hash Aggregation Map after we have populated - // contents of that map. - private[this] val aggregationBufferMapIterator = buffers.iterator() - - private[this] var _mapIteratorHasNext = false - - // Pre-load the first key-value pair from the map to make hasNext idempotent. - if (!sortBased) { - _mapIteratorHasNext = aggregationBufferMapIterator.next() - // If the map is empty, we just free it. - if (!_mapIteratorHasNext) { - buffers.free() - } - } - - /////////////////////////////////////////////////////////////////////////// - // Iterator's public methods - /////////////////////////////////////////////////////////////////////////// - - override final def hasNext: Boolean = { - (sortBased && sortBasedAggregationIterator.hasNext) || (!sortBased && _mapIteratorHasNext) - } - - - override final def next(): InternalRow = { - if (hasNext) { - if (sortBased) { - sortBasedAggregationIterator.next() - } else { - // We did not fall back to the sort-based aggregation. - val result = - generateOutput( - aggregationBufferMapIterator.getKey, - aggregationBufferMapIterator.getValue) - // Pre-load next key-value pair form aggregationBufferMapIterator. - _mapIteratorHasNext = aggregationBufferMapIterator.next() - - if (!_mapIteratorHasNext) { - val resultCopy = result.copy() - buffers.free() - resultCopy - } else { - result - } - } - } else { - // no more result - throw new NoSuchElementException - } - } -} - -object UnsafeHybridAggregationIterator { - // scalastyle:off - def createFromInputIterator( - groupingExprs: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], - outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = { - new UnsafeHybridAggregationIterator( - groupingExprs.map(_.toAttribute), - inputAttributes, - AggregationIterator.unsafeKVIterator(groupingExprs, inputAttributes, inputIter), - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) - } - // scalastyle:on -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 960be08f84d941c9e39ce180c9a1c6e06110a275..80816a095ea8c0a295a8a16c2fcd19046fafebb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -17,20 +17,41 @@ package org.apache.spark.sql.execution.aggregate +import scala.collection.mutable + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan} +import org.apache.spark.sql.types.StructType /** * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object Utils { + def supportsTungstenAggregate( + groupingExpressions: Seq[Expression], + aggregateBufferAttributes: Seq[Attribute]): Boolean = { + val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeProjection.canSupport(groupingExpressions) + } + def planAggregateWithoutDistinct( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + // Check if we can use TungstenAggregate. + val usesTungstenAggregate = + child.sqlContext.conf.unsafeEnabled && + aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) && + supportsTungstenAggregate( + groupingExpressions, + aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) + + // 1. Create an Aggregate Operator for partial aggregations. val namedGroupingExpressions = groupingExpressions.map { case ne: NamedExpression => ne -> ne @@ -44,11 +65,23 @@ object Utils { val groupExpressionMap = namedGroupingExpressions.toMap val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) - val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } - val partialAggregate = - Aggregate( + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) + val partialResultExpressions = + namedGroupingAttributes ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + + val partialAggregate = if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = None: Option[Seq[Expression]], + groupingExpressions = namedGroupingExpressions.map(_._2), + nonCompleteAggregateExpressions = partialAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialResultExpressions, + child = child) + } else { + SortBasedAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], groupingExpressions = namedGroupingExpressions.map(_._2), nonCompleteAggregateExpressions = partialAggregateExpressions, @@ -56,29 +89,57 @@ object Utils { completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, initialInputBufferOffset = 0, - resultExpressions = namedGroupingAttributes ++ partialAggregateAttributes, + resultExpressions = partialResultExpressions, child = child) + } // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 } - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - val finalAggregate = - Aggregate( + + val finalAggregate = if (usesTungstenAggregate) { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case agg: AggregateExpression2 => + // aggregateFunctionMap contains unique aggregate functions. + val aggregateFunction = + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1 + aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + TungstenAggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = namedGroupingAttributes.length, + resultExpressions = rewrittenResultExpressions, + child = partialAggregate) + } else { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case agg: AggregateExpression2 => + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + SortBasedAggregate( requiredChildDistributionExpressions = Some(namedGroupingAttributes), groupingExpressions = namedGroupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, @@ -88,6 +149,7 @@ object Utils { initialInputBufferOffset = namedGroupingAttributes.length, resultExpressions = rewrittenResultExpressions, child = partialAggregate) + } finalAggregate :: Nil } @@ -96,10 +158,18 @@ object Utils { groupingExpressions: Seq[Expression], functionsWithDistinct: Seq[AggregateExpression2], functionsWithoutDistinct: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct + val usesTungstenAggregate = + child.sqlContext.conf.unsafeEnabled && + aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) && + supportsTungstenAggregate( + groupingExpressions, + aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) + // 1. Create an Aggregate Operator for partial aggregations. // The grouping expressions are original groupingExpressions and // distinct columns. For example, for avg(distinct value) ... group by key @@ -129,19 +199,26 @@ object Utils { val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute) - val partialAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, Partial, false) - } - val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } + val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) val partialAggregateGroupingExpressions = (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2) val partialAggregateResult = - namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes - val partialAggregate = - Aggregate( + namedGroupingAttributes ++ + distinctColumnAttributes ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + val partialAggregate = if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = None: Option[Seq[Expression]], + groupingExpressions = partialAggregateGroupingExpressions, + nonCompleteAggregateExpressions = partialAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialAggregateResult, + child = child) + } else { + SortBasedAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], groupingExpressions = partialAggregateGroupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, @@ -151,20 +228,27 @@ object Utils { initialInputBufferOffset = 0, resultExpressions = partialAggregateResult, child = child) + } // 2. Create an Aggregate Operator for partial merge aggregations. - val partialMergeAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, PartialMerge, false) - } + val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) val partialMergeAggregateResult = - namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes - val partialMergeAggregate = - Aggregate( + namedGroupingAttributes ++ + distinctColumnAttributes ++ + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + val partialMergeAggregate = if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, + nonCompleteAggregateExpressions = partialMergeAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = partialMergeAggregateResult, + child = partialAggregate) + } else { + SortBasedAggregate( requiredChildDistributionExpressions = Some(namedGroupingAttributes), groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, @@ -174,48 +258,91 @@ object Utils { initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) + } // 3. Create an Aggregate Operator for partial merge aggregations. - val finalAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, Final, false) - } + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 } + // Create a map to store those rewritten aggregate functions. We always need to use + // both function and its corresponding isDistinct flag as the key because function itself + // does not knows if it is has distinct keyword or now. + val rewrittenAggregateFunctions = + mutable.Map.empty[(AggregateFunction2, Boolean), AggregateFunction2] val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. - case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) => + case agg @ AggregateExpression2(aggregateFunction, mode, true) => val rewrittenAggregateFunction = aggregateFunction.transformDown { case expr if distinctColumnExpressionMap.contains(expr) => distinctColumnExpressionMap(expr).toAttribute }.asInstanceOf[AggregateFunction2] + // Because we have rewritten the aggregate function, we use rewrittenAggregateFunctions + // to track the old version and the new version of this function. + rewrittenAggregateFunctions += (aggregateFunction, true) -> rewrittenAggregateFunction // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. val rewrittenAggregateExpression = - AggregateExpression2(rewrittenAggregateFunction, Complete, false) + AggregateExpression2(rewrittenAggregateFunction, Complete, true) - val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct) + val aggregateFunctionAttribute = + aggregateFunctionMap(agg.aggregateFunction, true)._2 (rewrittenAggregateExpression -> aggregateFunctionAttribute) }.unzip - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transform { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - val finalAndCompleteAggregate = - Aggregate( + val finalAndCompleteAggregate = if (usesTungstenAggregate) { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transform { + case agg: AggregateExpression2 => + val function = agg.aggregateFunction + val isDistinct = agg.isDistinct + val aggregateFunction = + if (rewrittenAggregateFunctions.contains(function, isDistinct)) { + // If this function has been rewritten, we get the rewritten version from + // rewrittenAggregateFunctions. + rewrittenAggregateFunctions(function, isDistinct) + } else { + // Oterwise, we get it from aggregateFunctionMap, which contains unique + // aggregate functions that have not been rewritten. + aggregateFunctionMap(function, isDistinct)._1 + } + aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + TungstenAggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + completeAggregateExpressions = completeAggregateExpressions, + initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = rewrittenResultExpressions, + child = partialMergeAggregate) + } else { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transform { + case agg: AggregateExpression2 => + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + SortBasedAggregate( requiredChildDistributionExpressions = Some(namedGroupingAttributes), groupingExpressions = namedGroupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, @@ -225,6 +352,7 @@ object Utils { initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, resultExpressions = rewrittenResultExpressions, child = partialMergeAggregate) + } finalAndCompleteAggregate :: Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cef40dd324d9efd5fd6fb000b43d9ec7e74c17ff..c64aa7a07dc2bfb149ae799512467c11e62824cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -262,7 +262,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = sql(sqlText) // First, check if we have GeneratedAggregate. val hasGeneratedAgg = df.queryExecution.executedPlan - .collect { case _: aggregate.Aggregate => true } + .collect { case _: aggregate.TungstenAggregate => true } .nonEmpty if (!hasGeneratedAgg) { fail( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 4b35c8fd83533c45c210442d69f9f0d7993fed37..7b5aa4763fd9ea2f6e8e052910e2f8b22e4150d0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -21,9 +21,9 @@ import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row} +import org.apache.spark.sql._ import org.scalatest.BeforeAndAfterAll -import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} +import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { @@ -141,6 +141,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Nil) } + test("null literal") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(null), + | COUNT(null), + | FIRST(null), + | LAST(null), + | MAX(null), + | MIN(null), + | SUM(null) + """.stripMargin), + Row(null, 0, null, null, null, null, null) :: Nil) + } + test("only do grouping") { checkAnswer( sqlContext.sql( @@ -266,13 +282,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be |SELECT avg(value) FROM agg1 """.stripMargin), Row(11.125) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT avg(null) - """.stripMargin), - Row(null) :: Nil) } test("udaf") { @@ -364,7 +373,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be | max(distinct value1) |FROM agg2 """.stripMargin), - Row(-60, 70.0, 101.0/9.0, 5.6, 100.0)) + Row(-60, 70.0, 101.0/9.0, 5.6, 100)) checkAnswer( sqlContext.sql( @@ -402,6 +411,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: Row(3, null, 3.0, null, null, null) :: Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value1), + | count(*), + | count(1), + | count(DISTINCT value1), + | key + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(3, 3, 3, 2, 1) :: + Row(3, 4, 4, 2, 2) :: + Row(0, 2, 2, 0, 3) :: + Row(3, 4, 4, 3, null) :: Nil) } test("test count") { @@ -496,7 +522,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be |FROM agg1 |GROUP BY key """.stripMargin).queryExecution.executedPlan.collect { - case agg: aggregate.Aggregate => agg + case agg: aggregate.SortBasedAggregate => agg + case agg: aggregate.TungstenAggregate => agg } val message = "We should fallback to the old aggregation code path if " + @@ -537,3 +564,58 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite { sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) } } + +class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { + + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") + } + + override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = { + (0 to 2).foreach { fallbackStartsAt => + sqlContext.setConf( + "spark.sql.TungstenAggregate.testFallbackStartsAt", + fallbackStartsAt.toString) + + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + val newActual = DataFrame(sqlContext, actual.logicalPlan) + + QueryTest.checkAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using TungstenAggregate with + |controlled fallback (it falls back to sort-based aggregation once it has processed + |$fallbackStartsAt input rows). The query is + |${actual.queryExecution} + | + |$errorMessage + """.stripMargin + + fail(newErrorMessage) + case None => + } + } + } + + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } +}