Commit 970ab8f6 authored by Wenchen Fan's avatar Wenchen Fan Committed by Yin Huai

[SPARK-17187][SQL][FOLLOW-UP] improve document of TypedImperativeAggregate

## What changes were proposed in this pull request?

Improve the documentation to make it easier to understand, and also mention the window operator.

## How was this patch tested?

N/A

Author: Wenchen Fan <wenchen@databricks.com>

Closes #14822 from cloud-fan/object-agg.
parent 28ab1792
@@ -390,48 +390,69 @@ abstract class DeclarativeAggregate
}
}
/**
* Aggregation function which allows **arbitrary** user-defined java object to be used as internal
* aggregation buffer.
*
* {{{
*   aggregation buffer for normal aggregation function `avg`    aggregate buffer for `sum`
*                  |                                                         |
*                  v                                                         v
*  +--------------+---------------+-----------------------------------+-------------+
*  |  sum1 (Long) | count1 (Long) | generic user-defined java objects | sum2 (Long) |
*  +--------------+---------------+-----------------------------------+-------------+
*                                                    ^
*                                                    |
*             aggregation buffer object for `TypedImperativeAggregate` aggregation function
* }}}
*
* General work flow:
*
* Stage 1: initialize aggregate buffer object.
*
* 1. The framework calls `initialize(buffer: MutableRow)` to set up the empty aggregate buffer.
* 2. In `initialize`, we call `createAggregationBuffer(): T` to get the initial buffer object,
* and set it to the global buffer row.
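*
* Conceptually, `initialize` just stores the newly created buffer object into the global
* buffer row (a framework-side sketch, not literal code):
* {{{
*   buffer(mutableAggBufferOffset) = createAggregationBuffer()
* }}}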
*
*
* Stage 2: process input rows.
*
* If the aggregate mode is `Partial` or `Complete`:
* 1. The framework calls `update(buffer: MutableRow, input: InternalRow)` to process the input
* row.
* 2. In `update`, we get the buffer object from the global buffer row and call
* `update(buffer: T, input: InternalRow): Unit`.
*
* If the aggregate mode is `PartialMerge` or `Final`:
* 1. The framework calls `merge(buffer: MutableRow, inputBuffer: InternalRow)` to process the
* input row, which is a serialized buffer object shuffled from another node.
* 2. In `merge`, we get the buffer object from the global buffer row, get the binary data from
* the input row and deserialize it to a buffer object, then call
* `merge(buffer: T, input: T): Unit` to merge the two buffer objects.
*
*
* Stage 3: output results.
*
* If the aggregate mode is `Partial` or `PartialMerge`:
* 1. The framework calls `serializeAggregateBufferInPlace` to replace the buffer object in the
* global buffer row with binary data.
* 2. In `serializeAggregateBufferInPlace`, we get the buffer object from the global buffer row
* and call `serialize(buffer: T): Array[Byte]` to serialize the buffer object to binary.
* 3. The framework outputs buffer attributes and shuffles them to other nodes.
*
* If the aggregate mode is `Final` or `Complete`:
* 1. The framework calls `eval(buffer: InternalRow)` to calculate the final result.
* 2. In `eval`, we get the buffer object from the global buffer row and call
* `eval(buffer: T): Any` to get the final result.
* 3. The framework outputs these final results.
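*
* Putting the stages together for a typical two-phase plan (Partial mode at the mapper side,
* Final mode at the reducer side), the call sequence looks roughly like this (a framework-side
* sketch, not literal code):
* {{{
*   // mapper side, Partial mode:
*   initialize(buffer)                       // buffer row now holds createAggregationBuffer()
*   update(buffer, inputRow)                 // called once per input row
*   serializeAggregateBufferInPlace(buffer)  // object -> Array[Byte] before shuffling
*
*   // reducer side, Final mode:
*   initialize(buffer)
*   merge(buffer, inputBuffer)               // called once per shuffled serialized buffer
*   eval(buffer)                             // produce the final result
* }}}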
*
*
* Window function work flow:
* The framework calls `update(buffer: MutableRow, input: InternalRow)` several times and then
* calls `eval(buffer: InternalRow)`, so there is no need for the window operator to call
* `serializeAggregateBufferInPlace`.
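*
* As an end-to-end illustration only (a hypothetical function, not part of this patch), a typed
* max over a non-null long column could be sketched as:
* {{{
*   class MaxHolder(var value: Long)
*
*   case class TypedMax(child: Expression) extends TypedImperativeAggregate[MaxHolder] {
*     override def createAggregationBuffer(): MaxHolder = new MaxHolder(Long.MinValue)
*
*     override def update(buffer: MaxHolder, input: InternalRow): Unit = {
*       val v = child.eval(input).asInstanceOf[Long]
*       if (v > buffer.value) buffer.value = v
*     }
*
*     override def merge(buffer: MaxHolder, input: MaxHolder): Unit = {
*       if (input.value > buffer.value) buffer.value = input.value
*     }
*
*     override def eval(buffer: MaxHolder): Any = buffer.value
*
*     // storage format is a plain 8-byte long
*     override def serialize(buffer: MaxHolder): Array[Byte] =
*       java.nio.ByteBuffer.allocate(8).putLong(buffer.value).array()
*
*     override def deserialize(bytes: Array[Byte]): MaxHolder =
*       new MaxHolder(java.nio.ByteBuffer.wrap(bytes).getLong)
*
*     // offsets, dataType, children and other ImperativeAggregate members elided
*   }
* }}}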
*
*
* NOTE: SQL with TypedImperativeAggregate functions is planned in sort-based aggregation,
* instead of hash-based aggregation, as TypedImperativeAggregate uses BinaryType as aggregation
@@ -489,25 +510,23 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
}
final override def update(buffer: MutableRow, input: InternalRow): Unit = {
update(getBufferObject(buffer), input)
}
final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
val bufferObject = getBufferObject(buffer)
// The inputBuffer stores the serialized aggregation buffer object produced by the partial aggregate
val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
merge(bufferObject, inputObject)
}
final override def eval(buffer: InternalRow): Any = {
eval(getBufferObject(buffer))
}
private[this] val anyObjectType = ObjectType(classOf[AnyRef])
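// Reads the user-defined buffer object back out of the buffer row. The field is fetched with
// a generic ObjectType, since only this class knows the concrete buffer type T.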
private def getBufferObject(bufferRow: InternalRow): T = {
bufferRow.get(mutableAggBufferOffset, anyObjectType).asInstanceOf[T]
}
final override lazy val aggBufferAttributes: Seq[AttributeReference] = {
@@ -524,9 +543,11 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
* In-place replaces the aggregation buffer object stored at the buffer's index
* `mutableAggBufferOffset` with Spark SQL's internally supported underlying storage format
* (BinaryType).
*
* This is only called when doing Partial or PartialMerge mode aggregation, before the framework
* shuffles out aggregate buffers.
*/
final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
buffer(mutableAggBufferOffset) = serialize(getBufferObject(buffer))
}
}