diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 5f7341e88c7c9c52162adec56e9d9c8aa8ae906e..8e0fbd109b4133f617d4940888b1f98e8e14970f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.unsafe.KVIterator
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -412,85 +411,3 @@ abstract class AggregationIterator(
    */
   protected def newBuffer: MutableRow
 }
-
-object AggregationIterator {
-  def kvIterator(
-      groupingExpressions: Seq[NamedExpression],
-      newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
-      inputAttributes: Seq[Attribute],
-      inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = {
-    new KVIterator[InternalRow, InternalRow] {
-      private[this] val groupingKeyGenerator = newProjection(groupingExpressions, inputAttributes)
-
-      private[this] var groupingKey: InternalRow = _
-
-      private[this] var value: InternalRow = _
-
-      override def next(): Boolean = {
-        if (inputIter.hasNext) {
-          // Read the next input row.
-          val inputRow = inputIter.next()
-          // Get groupingKey based on groupingExpressions.
-          groupingKey = groupingKeyGenerator(inputRow)
-          // The value is the inputRow.
-          value = inputRow
-          true
-        } else {
-          false
-        }
-      }
-
-      override def getKey(): InternalRow = {
-        groupingKey
-      }
-
-      override def getValue(): InternalRow = {
-        value
-      }
-
-      override def close(): Unit = {
-        // Do nothing
-      }
-    }
-  }
-
-  def unsafeKVIterator(
-      groupingExpressions: Seq[NamedExpression],
-      inputAttributes: Seq[Attribute],
-      inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = {
-    new KVIterator[UnsafeRow, InternalRow] {
-      private[this] val groupingKeyGenerator =
-        UnsafeProjection.create(groupingExpressions, inputAttributes)
-
-      private[this] var groupingKey: UnsafeRow = _
-
-      private[this] var value: InternalRow = _
-
-      override def next(): Boolean = {
-        if (inputIter.hasNext) {
-          // Read the next input row.
-          val inputRow = inputIter.next()
-          // Get groupingKey based on groupingExpressions.
-          groupingKey = groupingKeyGenerator.apply(inputRow)
-          // The value is the inputRow.
-          value = inputRow
-          true
-        } else {
-          false
-        }
-      }
-
-      override def getKey(): UnsafeRow = {
-        groupingKey
-      }
-
-      override def getValue(): InternalRow = {
-        value
-      }
-
-      override def close(): Unit = {
-        // Do nothing
-      }
-    }
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index f4c14a9b3556f28b8f57c33cd9e3783488a4acb4..4d37106e007f5f36afe02c7872dade0b80ae0edc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
-import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.StructType
 
 case class SortBasedAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -79,18 +78,23 @@ case class SortBasedAggregate(
         // so return an empty iterator.
         Iterator[InternalRow]()
       } else {
-        val outputIter = SortBasedAggregationIterator.createFromInputIterator(
-          groupingExpressions,
+        val groupingKeyProjection = if (UnsafeProjection.canSupport(groupingExpressions)) {
+          UnsafeProjection.create(groupingExpressions, child.output)
+        } else {
+          newMutableProjection(groupingExpressions, child.output)()
+        }
+        val outputIter = new SortBasedAggregationIterator(
+          groupingKeyProjection,
+          groupingExpressions.map(_.toAttribute),
+          child.output,
+          iter,
           nonCompleteAggregateExpressions,
           nonCompleteAggregateAttributes,
           completeAggregateExpressions,
           completeAggregateAttributes,
           initialInputBufferOffset,
           resultExpressions,
-          newMutableProjection _,
-          newProjection _,
-          child.output,
-          iter,
+          newMutableProjection,
           outputsUnsafeRows,
           numInputRows,
           numOutputRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index a9e5d175bf89590f5a3cea3e3d1ba34e3c5dcc16..64c673064f576be763c549936ef99ab004aad931 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -21,16 +21,16 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2}
 import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.unsafe.KVIterator
 
 /**
  * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been
  * sorted by values of [[groupingKeyAttributes]].
  */
 class SortBasedAggregationIterator(
+    groupingKeyProjection: InternalRow => InternalRow,
     groupingKeyAttributes: Seq[Attribute],
     valueAttributes: Seq[Attribute],
-    inputKVIterator: KVIterator[InternalRow, InternalRow],
+    inputIterator: Iterator[InternalRow],
     nonCompleteAggregateExpressions: Seq[AggregateExpression2],
     nonCompleteAggregateAttributes: Seq[Attribute],
     completeAggregateExpressions: Seq[AggregateExpression2],
@@ -90,6 +90,22 @@ class SortBasedAggregationIterator(
   // The aggregation buffer used by the sort-based aggregation.
   private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer
 
+  protected def initialize(): Unit = {
+    if (inputIterator.hasNext) {
+      initializeBuffer(sortBasedAggregationBuffer)
+      val inputRow = inputIterator.next()
+      nextGroupingKey = groupingKeyProjection(inputRow).copy()
+      firstRowInNextGroup = inputRow.copy()
+      numInputRows += 1
+      sortedInputHasNewGroup = true
+    } else {
+      // This inputIter is empty.
+      sortedInputHasNewGroup = false
+    }
+  }
+
+  initialize()
+
   /** Processes rows in the current group. It will stop when it find a new group. */
   protected def processCurrentSortedGroup(): Unit = {
     currentGroupingKey = nextGroupingKey
@@ -101,18 +117,15 @@ class SortBasedAggregationIterator(
 
     // The search will stop when we see the next group or there is no
     // input row left in the iter.
-    var hasNext = inputKVIterator.next()
-    while (!findNextPartition && hasNext) {
+    while (!findNextPartition && inputIterator.hasNext) {
       // Get the grouping key.
-      val groupingKey = inputKVIterator.getKey
-      val currentRow = inputKVIterator.getValue
+      val currentRow = inputIterator.next()
+      val groupingKey = groupingKeyProjection(currentRow)
       numInputRows += 1
 
       // Check if the current row belongs the current input row.
       if (currentGroupingKey == groupingKey) {
         processRow(sortBasedAggregationBuffer, currentRow)
-
-        hasNext = inputKVIterator.next()
       } else {
         // We find a new group.
         findNextPartition = true
@@ -149,68 +162,8 @@ class SortBasedAggregationIterator(
     }
   }
 
-  protected def initialize(): Unit = {
-    if (inputKVIterator.next()) {
-      initializeBuffer(sortBasedAggregationBuffer)
-
-      nextGroupingKey = inputKVIterator.getKey().copy()
-      firstRowInNextGroup = inputKVIterator.getValue().copy()
-      numInputRows += 1
-      sortedInputHasNewGroup = true
-    } else {
-      // This inputIter is empty.
-      sortedInputHasNewGroup = false
-    }
-  }
-
-  initialize()
-
   def outputForEmptyGroupingKeyWithoutInput(): InternalRow = {
     initializeBuffer(sortBasedAggregationBuffer)
     generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer)
   }
 }
-
-object SortBasedAggregationIterator {
-  // scalastyle:off
-  def createFromInputIterator(
-      groupingExprs: Seq[NamedExpression],
-      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
-      nonCompleteAggregateAttributes: Seq[Attribute],
-      completeAggregateExpressions: Seq[AggregateExpression2],
-      completeAggregateAttributes: Seq[Attribute],
-      initialInputBufferOffset: Int,
-      resultExpressions: Seq[NamedExpression],
-      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-      newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
-      inputAttributes: Seq[Attribute],
-      inputIter: Iterator[InternalRow],
-      outputsUnsafeRows: Boolean,
-      numInputRows: LongSQLMetric,
-      numOutputRows: LongSQLMetric): SortBasedAggregationIterator = {
-    val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) {
-      AggregationIterator.unsafeKVIterator(
-        groupingExprs,
-        inputAttributes,
-        inputIter).asInstanceOf[KVIterator[InternalRow, InternalRow]]
-    } else {
-      AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter)
-    }
-
-    new SortBasedAggregationIterator(
-      groupingExprs.map(_.toAttribute),
-      inputAttributes,
-      kvIterator,
-      nonCompleteAggregateExpressions,
-      nonCompleteAggregateAttributes,
-      completeAggregateExpressions,
-      completeAggregateAttributes,
-      initialInputBufferOffset,
-      resultExpressions,
-      newMutableProjection,
-      outputsUnsafeRows,
-      numInputRows,
-      numOutputRows)
-  }
-  // scalastyle:on
-}