Skip to content
Snippets Groups Projects
Commit 595012ea authored by Josh Rosen's avatar Josh Rosen
Browse files

[SPARK-11053] Remove use of KVIterator in SortBasedAggregationIterator

SortBasedAggregationIterator uses a KVIterator interface in order to process input rows as key-value pairs, but this use of KVIterator is unnecessary, slightly complicates the code, and might hurt performance. This patch refactors this code to remove the use of this extra layer of iterator wrapping and simplifies other parts of the code in the process.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #9066 from JoshRosen/sort-iterator-cleanup.
parent a16396df
No related branches found
No related tags found
No related merge requests found
......@@ -21,7 +21,6 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.unsafe.KVIterator
import scala.collection.mutable.ArrayBuffer
......@@ -412,85 +411,3 @@ abstract class AggregationIterator(
*/
protected def newBuffer: MutableRow
}
/**
 * Factory methods that adapt an iterator of input rows into a [[KVIterator]] whose key is
 * the grouping key computed from each row and whose value is the row itself.
 */
object AggregationIterator {

  /**
   * Builds a [[KVIterator]] over `inputIter` whose keys are produced by a projection obtained
   * from `newProjection(groupingExpressions, inputAttributes)`.
   *
   * @param groupingExpressions expressions that define the grouping key
   * @param newProjection factory used to build the grouping-key projection
   * @param inputAttributes attributes the grouping expressions are bound against
   * @param inputIter the rows to iterate over
   * @return an iterator exposing (grouping key, input row) pairs; the value is the very row
   *         returned by `inputIter` (no copy is made here)
   */
  def kvIterator(
      groupingExpressions: Seq[NamedExpression],
      newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
      inputAttributes: Seq[Attribute],
      inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = {
    new KVIterator[InternalRow, InternalRow] {
      private[this] val keyProjection = newProjection(groupingExpressions, inputAttributes)

      // Key and value of the most recently consumed row; unset until next() returns true.
      private[this] var currentKey: InternalRow = _
      private[this] var currentValue: InternalRow = _

      /** Advances to the next input row, returning false once the input is exhausted. */
      override def next(): Boolean = {
        val advanced = inputIter.hasNext
        if (advanced) {
          val row = inputIter.next()
          // Derive the key first, then expose the untouched input row as the value.
          currentKey = keyProjection(row)
          currentValue = row
        }
        advanced
      }

      override def getKey(): InternalRow = currentKey

      override def getValue(): InternalRow = currentValue

      // This wrapper holds nothing that needs releasing.
      override def close(): Unit = ()
    }
  }

  /**
   * Same as [[kvIterator]], except that the grouping key is computed with an
   * [[UnsafeProjection]] and is therefore exposed as an [[UnsafeRow]].
   *
   * @param groupingExpressions expressions that define the grouping key
   * @param inputAttributes attributes the grouping expressions are bound against
   * @param inputIter the rows to iterate over
   * @return an iterator exposing (unsafe grouping key, input row) pairs
   */
  def unsafeKVIterator(
      groupingExpressions: Seq[NamedExpression],
      inputAttributes: Seq[Attribute],
      inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = {
    new KVIterator[UnsafeRow, InternalRow] {
      private[this] val keyProjection =
        UnsafeProjection.create(groupingExpressions, inputAttributes)

      // Key and value of the most recently consumed row; unset until next() returns true.
      private[this] var currentKey: UnsafeRow = _
      private[this] var currentValue: InternalRow = _

      /** Advances to the next input row, returning false once the input is exhausted. */
      override def next(): Boolean = {
        val advanced = inputIter.hasNext
        if (advanced) {
          val row = inputIter.next()
          // Derive the key first, then expose the untouched input row as the value.
          currentKey = keyProjection.apply(row)
          currentValue = row
        }
        advanced
      }

      override def getKey(): UnsafeRow = currentKey

      override def getValue(): InternalRow = currentValue

      // This wrapper holds nothing that needs releasing.
      override def close(): Unit = ()
    }
  }
}
......@@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.StructType
case class SortBasedAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]],
......@@ -79,18 +78,23 @@ case class SortBasedAggregate(
// so return an empty iterator.
Iterator[InternalRow]()
} else {
val outputIter = SortBasedAggregationIterator.createFromInputIterator(
groupingExpressions,
val groupingKeyProjection = if (UnsafeProjection.canSupport(groupingExpressions)) {
UnsafeProjection.create(groupingExpressions, child.output)
} else {
newMutableProjection(groupingExpressions, child.output)()
}
val outputIter = new SortBasedAggregationIterator(
groupingKeyProjection,
groupingExpressions.map(_.toAttribute),
child.output,
iter,
nonCompleteAggregateExpressions,
nonCompleteAggregateAttributes,
completeAggregateExpressions,
completeAggregateAttributes,
initialInputBufferOffset,
resultExpressions,
newMutableProjection _,
newProjection _,
child.output,
iter,
newMutableProjection,
outputsUnsafeRows,
numInputRows,
numOutputRows)
......
......@@ -21,16 +21,16 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.execution.metric.LongSQLMetric
import org.apache.spark.unsafe.KVIterator
/**
* An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been
* sorted by values of [[groupingKeyAttributes]].
*/
class SortBasedAggregationIterator(
groupingKeyProjection: InternalRow => InternalRow,
groupingKeyAttributes: Seq[Attribute],
valueAttributes: Seq[Attribute],
inputKVIterator: KVIterator[InternalRow, InternalRow],
inputIterator: Iterator[InternalRow],
nonCompleteAggregateExpressions: Seq[AggregateExpression2],
nonCompleteAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2],
......@@ -90,6 +90,22 @@ class SortBasedAggregationIterator(
// The aggregation buffer used by the sort-based aggregation.
private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer
/**
 * Primes the iterator with the first input row, if there is one.
 *
 * When input is available, the aggregation buffer is initialized, the first row is consumed,
 * and copies of its grouping key and of the row itself are stored as the start of the next
 * group. `sortedInputHasNewGroup` ends up true exactly when a first row was found.
 */
protected def initialize(): Unit = {
  val inputIsNonEmpty = inputIterator.hasNext
  if (inputIsNonEmpty) {
    initializeBuffer(sortBasedAggregationBuffer)
    val firstRow = inputIterator.next()
    // Copies are stored rather than the originals.
    // NOTE(review): presumably because the projection and the upstream iterator may reuse
    // their row objects -- confirm.
    nextGroupingKey = groupingKeyProjection(firstRow).copy()
    firstRowInNextGroup = firstRow.copy()
    numInputRows += 1
  }
  // False means the input iterator was empty, so there is no group to process.
  sortedInputHasNewGroup = inputIsNonEmpty
}

// Run the priming step as part of construction.
initialize()
/** Processes rows in the current group. It will stop when it finds a new group. */
protected def processCurrentSortedGroup(): Unit = {
currentGroupingKey = nextGroupingKey
......@@ -101,18 +117,15 @@ class SortBasedAggregationIterator(
// The search will stop when we see the next group or there is no
// input row left in the iter.
var hasNext = inputKVIterator.next()
while (!findNextPartition && hasNext) {
while (!findNextPartition && inputIterator.hasNext) {
// Get the grouping key.
val groupingKey = inputKVIterator.getKey
val currentRow = inputKVIterator.getValue
val currentRow = inputIterator.next()
val groupingKey = groupingKeyProjection(currentRow)
numInputRows += 1
// Check if the current row belongs to the current group.
if (currentGroupingKey == groupingKey) {
processRow(sortBasedAggregationBuffer, currentRow)
hasNext = inputKVIterator.next()
} else {
// We find a new group.
findNextPartition = true
......@@ -149,68 +162,8 @@ class SortBasedAggregationIterator(
}
}
/**
 * Primes the iterator with the first key/value pair of the input, if there is one.
 *
 * When the key/value iterator yields a first entry, the aggregation buffer is initialized and
 * copies of that entry's key and value are stored as the start of the next group.
 * `sortedInputHasNewGroup` ends up true exactly when a first entry was found.
 */
protected def initialize(): Unit = {
  val inputIsNonEmpty = inputKVIterator.next()
  if (inputIsNonEmpty) {
    initializeBuffer(sortBasedAggregationBuffer)
    // Copies are stored rather than the rows handed out by the iterator.
    // NOTE(review): presumably because the iterator may reuse its row objects -- confirm.
    nextGroupingKey = inputKVIterator.getKey().copy()
    firstRowInNextGroup = inputKVIterator.getValue().copy()
    numInputRows += 1
  }
  // False means the input was empty, so there is no group to process.
  sortedInputHasNewGroup = inputIsNonEmpty
}

// Run the priming step as part of construction.
initialize()
/**
 * Produces the output row for the case of an empty grouping key and no input rows: the
 * aggregation buffer is re-initialized and the output is generated from it together with a
 * zero-field grouping key.
 */
def outputForEmptyGroupingKeyWithoutInput(): InternalRow = {
  initializeBuffer(sortBasedAggregationBuffer)
  val emptyGroupingKey = new GenericInternalRow(0)
  generateOutput(emptyGroupingKey, sortBasedAggregationBuffer)
}
}
/** Factory for [[SortBasedAggregationIterator]] instances. */
object SortBasedAggregationIterator {
  // scalastyle:off
  /**
   * Creates a [[SortBasedAggregationIterator]] that reads from a plain row iterator.
   *
   * The rows of `inputIter` are first wrapped into a key/value iterator keyed by the grouping
   * expressions: the unsafe variant is used when [[UnsafeProjection]] supports
   * `groupingExprs`, otherwise the generic variant built with `newProjection`. All remaining
   * arguments are forwarded unchanged to the iterator's constructor.
   *
   * @param groupingExprs expressions that define the grouping key
   * @param newProjection factory for the grouping-key projection on the non-unsafe path
   * @param inputAttributes attributes of the input rows
   * @param inputIter the input rows
   * @return the aggregation iterator over `inputIter`
   */
  def createFromInputIterator(
      groupingExprs: Seq[NamedExpression],
      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
      nonCompleteAggregateAttributes: Seq[Attribute],
      completeAggregateExpressions: Seq[AggregateExpression2],
      completeAggregateAttributes: Seq[Attribute],
      initialInputBufferOffset: Int,
      resultExpressions: Seq[NamedExpression],
      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
      newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
      inputAttributes: Seq[Attribute],
      inputIter: Iterator[InternalRow],
      outputsUnsafeRows: Boolean,
      numInputRows: LongSQLMetric,
      numOutputRows: LongSQLMetric): SortBasedAggregationIterator = {
    val unsafeKeysSupported = UnsafeProjection.canSupport(groupingExprs)
    val keyedInput: KVIterator[InternalRow, InternalRow] =
      if (unsafeKeysSupported) {
        // The unsafe variant is typed on UnsafeRow keys; it is cast to the InternalRow-keyed
        // type that the iterator's constructor accepts.
        val unsafeKeyed =
          AggregationIterator.unsafeKVIterator(groupingExprs, inputAttributes, inputIter)
        unsafeKeyed.asInstanceOf[KVIterator[InternalRow, InternalRow]]
      } else {
        AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter)
      }
    val groupingKeyAttributes = groupingExprs.map(_.toAttribute)

    new SortBasedAggregationIterator(
      groupingKeyAttributes,
      inputAttributes,
      keyedInput,
      nonCompleteAggregateExpressions,
      nonCompleteAggregateAttributes,
      completeAggregateExpressions,
      completeAggregateAttributes,
      initialInputBufferOffset,
      resultExpressions,
      newMutableProjection,
      outputsUnsafeRows,
      numInputRows,
      numOutputRows)
  }
  // scalastyle:on
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment