From dfcfcbcc0448ebc6f02eba6bf0495832a321c87e Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Tue, 10 Nov 2015 11:14:25 -0800
Subject: [PATCH] [SPARK-11578][SQL][FOLLOW-UP] complete the user facing api
 for typed aggregation

Currently the user-facing API for typed aggregation has some limitations:

* a customized typed aggregation must be the first entry in the aggregation list
* a customized typed aggregation can only use Long as its buffer type
* a customized typed aggregation can only use a flat type as its result type

This PR removes these limitations.
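
To illustrate (a minimal sketch, not part of the change itself; it is adapted
from the new DatasetAggregatorSuite tests below and assumes a SQLContext whose
implicits are in scope): an Aggregator may now use a tuple as its buffer type,
return a non-flat result type, and appear at any position in the aggregation
list.

    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions.expr
    import sqlContext.implicits._ // assumed in scope; provides toDS() and tuple encoders

    // Buffer is a (count, sum) tuple and the result is the tuple itself,
    // i.e. a non-flat type backed by a struct.
    object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)]
      with Serializable {

      override def zero: (Long, Long) = (0, 0)

      // Fold one input pair into the buffer.
      override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
        (countAndSum._1 + 1, countAndSum._2 + input._2)
      }

      // Combine two partial buffers (the method this patch adds to Aggregator).
      override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
        (b1._1 + b2._1, b1._2 + b2._2)
      }

      override def present(reduction: (Long, Long)): (Long, Long) = reduction
    }

    // The typed aggregation no longer has to come first in the aggregation list:
    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
    ds.groupBy(_._1).agg(expr("avg(_2)").as[Double], ComplexResultAgg.toColumn)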

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9599 from cloud-fan/agg.
---
 .../catalyst/encoders/ExpressionEncoder.scala |  6 +++
 .../aggregate/TypedAggregateExpression.scala  | 50 +++++++++++++-----
 .../spark/sql/expressions/Aggregator.scala    |  5 ++
 .../spark/sql/DatasetAggregatorSuite.scala    | 52 +++++++++++++++++++
 4 files changed, 99 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index c287aebeee..005c0627f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -185,6 +185,12 @@ case class ExpressionEncoder[T](
     })
   }
 
+  def shift(delta: Int): ExpressionEncoder[T] = {
+    copy(constructExpression = constructExpression transform {
+      case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
+    })
+  }
+
   /**
    * Returns a copy of this encoder where the expressions used to create an object given an
    * input row have been modified to pull the object out from a nested struct, instead of the
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 24d8122b62..0e5bc1f9ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.aggregate
 import scala.language.existentials
 
 import org.apache.spark.Logging
+import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
-import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{StructType, DataType}
+import org.apache.spark.sql.types._
 
 object TypedAggregateExpression {
   def apply[A, B : Encoder, C : Encoder](
@@ -67,8 +67,11 @@ case class TypedAggregateExpression(
 
   override def nullable: Boolean = true
 
-  // TODO: this assumes flat results...
-  override def dataType: DataType = cEncoder.schema.head.dataType
+  override def dataType: DataType = if (cEncoder.flat) {
+    cEncoder.schema.head.dataType
+  } else {
+    cEncoder.schema
+  }
 
   override def deterministic: Boolean = true
 
@@ -93,32 +96,51 @@ case class TypedAggregateExpression(
       case a: AttributeReference => inputMapping(a)
     })
 
-  // TODO: this probably only works when we are in the first column.
   val bAttributes = bEncoder.schema.toAttributes
   lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
 
+  private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
+    // todo: need a more neat way to assign the value.
+    var i = 0
+    while (i < aggBufferAttributes.length) {
+      aggBufferSchema(i).dataType match {
+        case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i))
+        case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i))
+      }
+      i += 1
+    }
+  }
+
   override def initialize(buffer: MutableRow): Unit = {
-    // TODO: We need to either force Aggregator to have a zero or we need to eliminate the need for
-    // this in execution.
-    buffer.setInt(mutableAggBufferOffset, aggregator.zero.asInstanceOf[Int])
+    val zero = bEncoder.toRow(aggregator.zero)
+    updateBuffer(buffer, zero)
   }
 
   override def update(buffer: MutableRow, input: InternalRow): Unit = {
     val inputA = boundA.fromRow(input)
-    val currentB = boundB.fromRow(buffer)
+    val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
     val merged = aggregator.reduce(currentB, inputA)
     val returned = boundB.toRow(merged)
-    buffer.setInt(mutableAggBufferOffset, returned.getInt(0))
+
+    updateBuffer(buffer, returned)
   }
 
   override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
-    buffer1.setLong(
-      mutableAggBufferOffset,
-      buffer1.getLong(mutableAggBufferOffset) + buffer2.getLong(inputAggBufferOffset))
+    val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1)
+    val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2)
+    val merged = aggregator.merge(b1, b2)
+    val returned = boundB.toRow(merged)
+
+    updateBuffer(buffer1, returned)
   }
 
   override def eval(buffer: InternalRow): Any = {
-    buffer.getInt(mutableAggBufferOffset)
+    val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
+    val result = cEncoder.toRow(aggregator.present(b))
+    dataType match {
+      case _: StructType => result
+      case _ => result.get(0, dataType)
+    }
   }
 
   override def toString: String = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 8cc25c2440..3c1c457e06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -57,6 +57,11 @@ abstract class Aggregator[-A, B, C] {
    */
   def reduce(b: B, a: A): B
 
+  /**
+   * Merge two intermediate values
+   */
+  def merge(b1: B, b2: B): B
+
   /**
    * Transform the output of the reduction.
   */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 340470c096..206095a519 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -34,9 +34,41 @@ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializ
 
   override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
 
+  override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
+
   override def present(reduction: N): N = reduction
 }
 
+object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with Serializable {
+  override def zero: (Long, Long) = (0, 0)
+
+  override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
+    (countAndSum._1 + 1, countAndSum._2 + input._2)
+  }
+
+  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
+    (b1._1 + b2._1, b1._2 + b2._2)
+  }
+
+  override def present(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1
+}
+
+object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)]
+  with Serializable {
+
+  override def zero: (Long, Long) = (0, 0)
+
+  override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
+    (countAndSum._1 + 1, countAndSum._2 + input._2)
+  }
+
+  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
+    (b1._1 + b2._1, b1._2 + b2._2)
+  }
+
+  override def present(reduction: (Long, Long)): (Long, Long) = reduction
+}
+
 class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
@@ -62,4 +94,24 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
         count("*")),
       ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L))
   }
+
+  test("typed aggregation: complex case") {
+    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(
+        expr("avg(_2)").as[Double],
+        TypedAverage.toColumn),
+      ("a", 2.0, 2.0), ("b", 3.0, 3.0))
+  }
+
+  test("typed aggregation: complex result type") {
+    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(
+        expr("avg(_2)").as[Double],
+        ComplexResultAgg.toColumn),
+      ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L)))
+  }
 }
-- 
GitLab