Skip to content
Snippets Groups Projects
Commit 79268aa4 authored by Hiroshi Inoue's avatar Hiroshi Inoue Committed by Wenchen Fan
Browse files

[SPARK-15704][SQL] add a test case in DatasetAggregatorSuite for regression testing

## What changes were proposed in this pull request?

This change fixes a crash in TungstenAggregate while executing "Dataset complex Aggregator" test case due to IndexOutOfBoundsException.

jira entry for detail: https://issues.apache.org/jira/browse/SPARK-15704

## How was this patch tested?
Using existing unit tests (including DatasetBenchmark)

Author: Hiroshi Inoue <inouehrs@jp.ibm.com>

Closes #13446 from inouehrs/fix_aggregate.
parent 26c1089c
No related branches found
No related tags found
No related merge requests found
...@@ -24,6 +24,7 @@ import org.apache.spark.sql.expressions.Aggregator ...@@ -24,6 +24,7 @@ import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StringType
object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] { object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
...@@ -52,6 +53,16 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] { ...@@ -52,6 +53,16 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] {
} }
object ClassBufferAggregator extends Aggregator[AggData, AggData, Int] {
override def zero: AggData = AggData(0, "")
override def reduce(b: AggData, a: AggData): AggData = AggData(b.a + a.a, "")
override def finish(reduction: AggData): Int = reduction.a
override def merge(b1: AggData, b2: AggData): AggData = AggData(b1.a + b2.a, "")
override def bufferEncoder: Encoder[AggData] = Encoders.product[AggData]
override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}
object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] { object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
override def zero: (Int, AggData) = 0 -> AggData(0, "0") override def zero: (Int, AggData) = 0 -> AggData(0, "0")
override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a) override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
...@@ -173,6 +184,14 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ...@@ -173,6 +184,14 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
("one", 1)) ("one", 1))
} }
test("Typed aggregation using aggregator") {
// based on Dataset complex Aggregator test of DatasetBenchmark
val ds = Seq(AggData(1, "x"), AggData(2, "y"), AggData(3, "z")).toDS()
checkDataset(
ds.select(ClassBufferAggregator.toColumn),
6)
}
test("typed aggregation: complex input") { test("typed aggregation: complex input") {
val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment