diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index ead7bd9642ecabde6aad28ee48377d8a5119dc98..f9b4cd83c3a422af128b1f2cfd7286d198f54c80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.StringType
 
 
 object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
@@ -52,6 +53,16 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] {
 }
 
 
+object ClassBufferAggregator extends Aggregator[AggData, AggData, Int] {
+  override def zero: AggData = AggData(0, "")
+  override def reduce(b: AggData, a: AggData): AggData = AggData(b.a + a.a, "")
+  override def finish(reduction: AggData): Int = reduction.a
+  override def merge(b1: AggData, b2: AggData): AggData = AggData(b1.a + b2.a, "")
+  override def bufferEncoder: Encoder[AggData] = Encoders.product[AggData]
+  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
+}
+
+
 object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
   override def zero: (Int, AggData) = 0 -> AggData(0, "0")
   override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
@@ -173,6 +184,14 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ("one", 1))
   }
 
+  test("Typed aggregation using aggregator") {
+    // based on Dataset complex Aggregator test of DatasetBenchmark
+    val ds = Seq(AggData(1, "x"), AggData(2, "y"), AggData(3, "z")).toDS()
+    checkDataset(
+      ds.select(ClassBufferAggregator.toColumn),
+      6)
+  }
+
   test("typed aggregation: complex input") {
     val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
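
For reference, a minimal standalone sketch of how an Aggregator with a case-class buffer, like the ClassBufferAggregator added above, can be driven outside the test harness. SharedSQLContext and checkDataset are internal Spark test utilities, so the sketch assumes a local SparkSession instead; SumByClassBuffer and ClassBufferDemo are hypothetical names introduced here for illustration.

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

// Same shape as the suite's AggData case class.
case class AggData(a: Int, b: String)

// Mirrors ClassBufferAggregator: the intermediate buffer is a case class
// encoded via Encoders.product, rather than a tuple or primitive buffer.
object SumByClassBuffer extends Aggregator[AggData, AggData, Int] {
  override def zero: AggData = AggData(0, "")
  override def reduce(b: AggData, a: AggData): AggData = AggData(b.a + a.a, "")
  override def merge(b1: AggData, b2: AggData): AggData = AggData(b1.a + b2.a, "")
  override def finish(reduction: AggData): Int = reduction.a
  override def bufferEncoder: Encoder[AggData] = Encoders.product[AggData]
  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

object ClassBufferDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("ClassBufferDemo")
      .getOrCreate()
    import spark.implicits._

    val ds = Seq(AggData(1, "x"), AggData(2, "y"), AggData(3, "z")).toDS()
    // toColumn turns the Aggregator into a TypedColumn; selecting it
    // yields a single-row Dataset[Int] containing 1 + 2 + 3 = 6,
    // matching the expected value asserted in the new test.
    ds.select(SumByClassBuffer.toColumn).show()

    spark.stop()
  }
}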