diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index ead7bd9642ecabde6aad28ee48377d8a5119dc98..f9b4cd83c3a422af128b1f2cfd7286d198f54c80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.StringType
 
 
 object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
@@ -52,6 +53,16 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] {
 }
 
 
+object ClassBufferAggregator extends Aggregator[AggData, AggData, Int] {
+  override def zero: AggData = AggData(0, "")
+  override def reduce(b: AggData, a: AggData): AggData = AggData(b.a + a.a, "")
+  override def finish(reduction: AggData): Int = reduction.a
+  override def merge(b1: AggData, b2: AggData): AggData = AggData(b1.a + b2.a, "")
+  override def bufferEncoder: Encoder[AggData] = Encoders.product[AggData]
+  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
+}
+
+
 object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
   override def zero: (Int, AggData) = 0 -> AggData(0, "0")
   override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
@@ -173,6 +184,14 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ("one", 1))
   }
 
+  test("Typed aggregation using aggregator") {
+    // based on Dataset complex Aggregator test of DatasetBenchmark
+    val ds = Seq(AggData(1, "x"), AggData(2, "y"), AggData(3, "z")).toDS()
+    checkDataset(
+      ds.select(ClassBufferAggregator.toColumn),
+      6)
+  }
+
   test("typed aggregation: complex input") {
     val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()