Skip to content
Snippets Groups Projects
Commit dfcfcbcc authored by Wenchen Fan's avatar Wenchen Fan Committed by Michael Armbrust
Browse files

[SPARK-11578][SQL][FOLLOW-UP] complete the user facing api for typed aggregation

Currently the user facing api for typed aggregation has some limitations:

* the customized typed aggregation must be the first of aggregation list
* the customized typed aggregation can only use long as buffer type
* the customized typed aggregation can only use flat type as result type

This PR tries to remove these limitations.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9599 from cloud-fan/agg.
parent 47735cdc
No related branches found
No related tags found
No related merge requests found
......@@ -185,6 +185,12 @@ case class ExpressionEncoder[T](
})
}
def shift(delta: Int): ExpressionEncoder[T] = {
copy(constructExpression = constructExpression transform {
case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
})
}
/**
* Returns a copy of this encoder where the expressions used to create an object given an
* input row have been modified to pull the object out from a nested struct, instead of the
......
......@@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.aggregate
import scala.language.existentials
import org.apache.spark.Logging
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{StructType, DataType}
import org.apache.spark.sql.types._
object TypedAggregateExpression {
def apply[A, B : Encoder, C : Encoder](
......@@ -67,8 +67,11 @@ case class TypedAggregateExpression(
override def nullable: Boolean = true
// TODO: this assumes flat results...
override def dataType: DataType = cEncoder.schema.head.dataType
override def dataType: DataType = if (cEncoder.flat) {
cEncoder.schema.head.dataType
} else {
cEncoder.schema
}
override def deterministic: Boolean = true
......@@ -93,32 +96,51 @@ case class TypedAggregateExpression(
case a: AttributeReference => inputMapping(a)
})
// TODO: this probably only works when we are in the first column.
val bAttributes = bEncoder.schema.toAttributes
lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
// todo: need a more neat way to assign the value.
var i = 0
while (i < aggBufferAttributes.length) {
aggBufferSchema(i).dataType match {
case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i))
case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i))
}
i += 1
}
}
override def initialize(buffer: MutableRow): Unit = {
// TODO: We need to either force Aggregator to have a zero or we need to eliminate the need for
// this in execution.
buffer.setInt(mutableAggBufferOffset, aggregator.zero.asInstanceOf[Int])
val zero = bEncoder.toRow(aggregator.zero)
updateBuffer(buffer, zero)
}
override def update(buffer: MutableRow, input: InternalRow): Unit = {
val inputA = boundA.fromRow(input)
val currentB = boundB.fromRow(buffer)
val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
val merged = aggregator.reduce(currentB, inputA)
val returned = boundB.toRow(merged)
buffer.setInt(mutableAggBufferOffset, returned.getInt(0))
updateBuffer(buffer, returned)
}
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
buffer1.setLong(
mutableAggBufferOffset,
buffer1.getLong(mutableAggBufferOffset) + buffer2.getLong(inputAggBufferOffset))
val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1)
val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2)
val merged = aggregator.merge(b1, b2)
val returned = boundB.toRow(merged)
updateBuffer(buffer1, returned)
}
override def eval(buffer: InternalRow): Any = {
buffer.getInt(mutableAggBufferOffset)
val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
val result = cEncoder.toRow(aggregator.present(b))
dataType match {
case _: StructType => result
case _ => result.get(0, dataType)
}
}
override def toString: String = {
......
......@@ -57,6 +57,11 @@ abstract class Aggregator[-A, B, C] {
*/
def reduce(b: B, a: A): B
/**
* Merge two intermediate values
*/
def merge(b1: B, b2: B): B
/**
* Transform the output of the reduction.
*/
......
......@@ -34,9 +34,41 @@ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializ
override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
override def present(reduction: N): N = reduction
}
object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with Serializable {
override def zero: (Long, Long) = (0, 0)
override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
(countAndSum._1 + 1, countAndSum._2 + input._2)
}
override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
(b1._1 + b2._1, b1._2 + b2._2)
}
override def present(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1
}
object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)]
with Serializable {
override def zero: (Long, Long) = (0, 0)
override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
(countAndSum._1 + 1, countAndSum._2 + input._2)
}
override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
(b1._1 + b2._1, b1._2 + b2._2)
}
override def present(reduction: (Long, Long)): (Long, Long) = reduction
}
class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
......@@ -62,4 +94,24 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
count("*")),
("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L))
}
test("typed aggregation: complex case") {
val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
checkAnswer(
ds.groupBy(_._1).agg(
expr("avg(_2)").as[Double],
TypedAverage.toColumn),
("a", 2.0, 2.0), ("b", 3.0, 3.0))
}
test("typed aggregation: complex result type") {
val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
checkAnswer(
ds.groupBy(_._1).agg(
expr("avg(_2)").as[Double],
ComplexResultAgg.toColumn),
("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L)))
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment