From 12854464c4fa30c4df3b5b17bd8914d048dbf4a9 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Sat, 16 Apr 2016 00:31:51 -0700
Subject: [PATCH] [SPARK-13363][SQL] support Aggregator in
 RelationalGroupedDataset

## What changes were proposed in this pull request?

set the input encoder for `TypedColumn` in `RelationalGroupedDataset.agg`.

## How was this patch tested?

new tests in `DatasetAggregatorSuite`

close https://github.com/apache/spark/pull/11269

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12359 from cloud-fan/agg.
---
 .../spark/sql/RelationalGroupedDataset.scala       |  6 +++++-
 .../apache/spark/sql/DatasetAggregatorSuite.scala  | 14 +++++++++++++-
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 7dbf2e6c7c..deb2e82165 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -208,7 +208,11 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def agg(expr: Column, exprs: Column*): DataFrame = {
-    toDF((expr +: exprs).map(_.expr))
+    toDF((expr +: exprs).map {
+      case typed: TypedColumn[_, _] =>
+        typed.withInputType(df.resolvedTEncoder, df.logicalPlan.output).expr
+      case c => c.expr
+    })
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 3a7215ee39..0d84a594f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql
 
 import scala.language.postfixOps
 
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scala.typed
 import org.apache.spark.sql.functions._
@@ -85,6 +84,15 @@ class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT)
   override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]]
 }
 
+object RowAgg extends Aggregator[Row, Int, Int] {
+  def zero: Int = 0
+  def reduce(b: Int, a: Row): Int = a.getInt(0) + b
+  def merge(b1: Int, b2: Int): Int = b1 + b2
+  def finish(r: Int): Int = r
+  override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
+}
+
 
 class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
@@ -200,4 +208,8 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
         (1279869254, "Some String"))
   }
 
+  test("aggregator in DataFrame/Dataset[Row]") {
+    val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
+    checkAnswer(df.groupBy($"j").agg(RowAgg.toColumn), Row("a", 1) :: Row("b", 5) :: Nil)
+  }
 }
-- 
GitLab