Skip to content
Snippets Groups Projects
Commit 12854464 authored by Wenchen Fan's avatar Wenchen Fan Committed by Reynold Xin
Browse files

[SPARK-13363][SQL] support Aggregator in RelationalGroupedDataset

## What changes were proposed in this pull request?

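Set the input encoder for `TypedColumn` arguments in `RelationalGroupedDataset.agg`, so that a typed `Aggregator` can be used when aggregating a `DataFrame` (`Dataset[Row]`).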

## How was this patch tested?

A new test, `aggregator in DataFrame/Dataset[Row]`, was added to `DatasetAggregatorSuite`, together with a `RowAgg` helper aggregator.

close https://github.com/apache/spark/pull/11269

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12359 from cloud-fan/agg.
parent f4be0946
No related branches found
No related tags found
No related merge requests found
...@@ -208,7 +208,11 @@ class RelationalGroupedDataset protected[sql]( ...@@ -208,7 +208,11 @@ class RelationalGroupedDataset protected[sql](
*/ */
@scala.annotation.varargs @scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame = { def agg(expr: Column, exprs: Column*): DataFrame = {
toDF((expr +: exprs).map(_.expr)) toDF((expr +: exprs).map {
case typed: TypedColumn[_, _] =>
typed.withInputType(df.resolvedTEncoder, df.logicalPlan.output).expr
case c => c.expr
})
} }
/** /**
......
...@@ -19,7 +19,6 @@ package org.apache.spark.sql ...@@ -19,7 +19,6 @@ package org.apache.spark.sql
import scala.language.postfixOps import scala.language.postfixOps
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.expressions.scala.typed
import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions._
...@@ -85,6 +84,15 @@ class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT) ...@@ -85,6 +84,15 @@ class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT)
override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]] override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]]
} }
/**
 * Test aggregator that consumes untyped [[Row]] input.
 *
 * The buffer starts at 0, each input row contributes the integer stored at ordinal 0,
 * and the accumulated buffer is returned unchanged as the final result.
 */
object RowAgg extends Aggregator[Row, Int, Int] {
  /** Initial buffer value for every group. */
  def zero: Int = 0

  /** Folds one input row into the buffer by adding the integer at ordinal 0 of `a` to `b`. */
  def reduce(b: Int, a: Row): Int = {
    val rowValue = a.getInt(0)
    rowValue + b
  }

  /** Combines two partial buffers by summing them. */
  def merge(b1: Int, b2: Int): Int = {
    b1 + b2
  }

  /** Produces the final result; the buffer already holds it, so it is passed through. */
  def finish(r: Int): Int = {
    r
  }

  /** Encoder for the intermediate `Int` buffer. */
  override def bufferEncoder: Encoder[Int] = {
    Encoders.scalaInt
  }

  /** Encoder for the `Int` output value. */
  override def outputEncoder: Encoder[Int] = {
    Encoders.scalaInt
  }
}
class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
...@@ -200,4 +208,8 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ...@@ -200,4 +208,8 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
(1279869254, "Some String")) (1279869254, "Some String"))
} }
// Checks that an Aggregator over Row can be passed to the untyped groupBy(...).agg(...) API.
test("aggregator in DataFrame/Dataset[Row]") {
  val input = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
  // RowAgg sums column "i" (ordinal 0): group "a" holds 1, group "b" holds 2 + 3 = 5.
  val expectedRows = Row("a", 1) :: Row("b", 5) :: Nil
  checkAnswer(
    input.groupBy($"j").agg(RowAgg.toColumn),
    expectedRows)
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment