Commit 946b4065 authored by Wenchen Fan, committed by Michael Armbrust

[SPARK-11913][SQL] support typed aggregate with complex buffer schema

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9898 from cloud-fan/agg.
parent f2996e0d
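For orientation, the user-facing capability this patch adds is the ability to use an `Aggregator` whose intermediate buffer is a complex type (a tuple, case class, or nested struct) rather than a single primitive column. A minimal sketch of the idea, mirroring the `ComplexBufferAgg` test added below; the aggregator name and the usage snippet are illustrative, not part of the patch:

```scala
import org.apache.spark.sql.expressions.Aggregator

case class AggData(a: Int, b: String)

// An aggregator whose buffer is a tuple containing a case class; before this
// patch only single-column primitive buffers (e.g. a Long count) were handled.
object CountWithLast extends Aggregator[AggData, (Int, AggData), Int] {
  // Start from a count of zero and a placeholder record.
  override def zero: (Int, AggData) = 0 -> AggData(0, "0")
  // Bump the count and remember the last-seen input.
  override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
  // Combine partial counts from different partitions.
  override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) =
    (b1._1 + b2._1, b1._2)
  // The final result is just the count.
  override def finish(reduction: (Int, AggData)): Int = reduction._1
}

// Usage (Spark 1.6-era Dataset API, assuming sqlContext implicits are in scope):
//   val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
//   ds.select(CountWithLast.toColumn).collect()   // Array(2)
```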
@@ -23,9 +23,8 @@ import org.apache.spark.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -46,14 +45,12 @@ object TypedAggregateExpression {
/**
 * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has
 * the following limitations:
 *  - It assumes the aggregator reduces and returns a single column of type `long`.
 *  - It might only work when there is a single aggregator in the first column.
 *  - It assumes the aggregator has a zero, `0`.
 */
case class TypedAggregateExpression(
    aggregator: Aggregator[Any, Any, Any],
    aEncoder: Option[ExpressionEncoder[Any]], // Should be bound.
    bEncoder: ExpressionEncoder[Any], // Should be bound.
    unresolvedBEncoder: ExpressionEncoder[Any],
    cEncoder: ExpressionEncoder[Any],
    children: Seq[Attribute],
    mutableAggBufferOffset: Int,
@@ -80,10 +77,14 @@ case class TypedAggregateExpression(
  override lazy val inputTypes: Seq[DataType] = Nil
  override val aggBufferSchema: StructType = bEncoder.schema
  override val aggBufferSchema: StructType = unresolvedBEncoder.schema
  override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
  val bEncoder = unresolvedBEncoder
    .resolve(aggBufferAttributes, OuterScopes.outerScopes)
    .bind(aggBufferAttributes)
  // Note: although this simply copies aggBufferAttributes, this common code can not be placed
  // in the superclass because that will lead to initialization ordering issues.
  override val inputAggBufferAttributes: Seq[AttributeReference] =
@@ -93,12 +94,18 @@ case class TypedAggregateExpression(
  lazy val boundA = aEncoder.get
  private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
    // todo: need a more neat way to assign the value.
    var i = 0
    while (i < aggBufferAttributes.length) {
      val offset = mutableAggBufferOffset + i
      aggBufferSchema(i).dataType match {
        case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i))
        case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i))
        case BooleanType => buffer.setBoolean(offset, value.getBoolean(i))
        case ByteType => buffer.setByte(offset, value.getByte(i))
        case ShortType => buffer.setShort(offset, value.getShort(i))
        case IntegerType => buffer.setInt(offset, value.getInt(i))
        case LongType => buffer.setLong(offset, value.getLong(i))
        case FloatType => buffer.setFloat(offset, value.getFloat(i))
        case DoubleType => buffer.setDouble(offset, value.getDouble(i))
        case other => buffer.update(offset, value.get(i, other))
      }
      i += 1
    }
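To make the mechanics above concrete: the buffer encoder flattens the complex buffer type into one top-level column per field, `schema.toAttributes` produces the matching `aggBufferAttributes`, and `updateBuffer` copies an `InternalRow` produced by the buffer encoder into the shared mutable buffer field by field, using the typed setters for primitives and the generic `update` for everything else (e.g. a nested struct). A rough sketch using the same internal encoder API, with illustrative values that are not taken from the patch:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class AggData(a: Int, b: String)

// Encoder for the tuple buffer used by ComplexBufferAgg in the new test.
val bufEncoder = ExpressionEncoder[(Int, AggData)]()

// Two top-level buffer columns: an IntegerType and a nested StructType for AggData.
// schema.toAttributes yields one AttributeReference per column; these become the
// aggBufferAttributes the encoder is resolved and bound against.
println(bufEncoder.schema)

// Serializing a buffer value gives an InternalRow with one entry per column.
val row: InternalRow = bufEncoder.toRow((3, AggData(1, "one")))
row.getInt(0)                                // copied via buffer.setInt(offset, ...)
row.get(1, bufEncoder.schema(1).dataType)    // copied via the generic buffer.update(offset, ...)
```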
@@ -67,7 +67,7 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, L
}
case class AggData(a: Int, b: String)
object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable {
object ClassInputAgg extends Aggregator[AggData, Int, Int] {
  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
  override def zero: Int = 0
@@ -88,6 +88,28 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable {
  override def merge(b1: Int, b2: Int): Int = b1 + b2
}
object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
  override def zero: (Int, AggData) = 0 -> AggData(0, "0")
  /**
   * Combine two values to produce a new value. For performance, the function may modify `b` and
   * return it instead of constructing new object for b.
   */
  override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
  /**
   * Transform the output of the reduction.
   */
  override def finish(reduction: (Int, AggData)): Int = reduction._1
  /**
   * Merge two intermediate values
   */
  override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) =
    (b1._1 + b2._1, b1._2)
}
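As a plain-Scala sanity check (not part of the patch), ComplexBufferAgg simply counts its input rows while carrying the last-seen AggData in the buffer, so the expected answers in the test below can be reproduced with an ordinary fold:

```scala
// Equivalent fold over a local collection, using the same zero/reduce/finish
// semantics as ComplexBufferAgg; the merge step only matters across partitions.
case class AggData(a: Int, b: String)

val data = Seq(AggData(1, "one"), AggData(2, "two"))
val (count, _) = data.foldLeft(0 -> AggData(0, "0")) {
  case ((n, _), a) => (n + 1, a)
}
assert(count == 2)  // matches checkAnswer(ds.select(ComplexBufferAgg.toColumn), 2)
```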
class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
  import testImplicits._
@@ -168,4 +190,21 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
      ds.groupBy(_.b).agg(ClassInputAgg.toColumn),
      ("one", 1))
  }
test("typed aggregation: complex input") {
val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
checkAnswer(
ds.select(ComplexBufferAgg.toColumn),
2
)
checkAnswer(
ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn),
(1.5, 2))
checkAnswer(
ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn),
("one", 1), ("two", 1))
}
}