Skip to content
Snippets Groups Projects
Commit 9560c8d2 authored by petermaxlee's avatar petermaxlee Committed by Wenchen Fan
Browse files

[SPARK-17124][SQL] RelationalGroupedDataset.agg should preserve order and...

[SPARK-17124][SQL] RelationalGroupedDataset.agg should preserve order and allow multiple aggregates per column

## What changes were proposed in this pull request?
This patch fixes a longstanding issue with one of the RelationalGroupedDataset.agg function. Even though the signature accepts vararg of pairs, the underlying implementation turns the seq into a map, and thus not order preserving nor allowing multiple aggregates per column.

This change also allows users to use this function to run multiple different aggregations for a single column, e.g.
```
agg("age" -> "max", "age" -> "count")
```

## How was this patch tested?
Added a test case in DataFrameAggregateSuite.

Author: petermaxlee <petermaxlee@gmail.com>

Closes #14697 from petermaxlee/SPARK-17124.
parent 31a01557
No related branches found
No related tags found
No related merge requests found
......@@ -128,7 +128,7 @@ class RelationalGroupedDataset protected[sql](
}
/**
* (Scala-specific) Compute aggregates by specifying a map from column name to
* (Scala-specific) Compute aggregates by specifying the column names and
* aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns.
*
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
......@@ -143,7 +143,9 @@ class RelationalGroupedDataset protected[sql](
* @since 1.3.0
*/
def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
agg((aggExpr +: aggExprs).toMap)
toDF((aggExpr +: aggExprs).map { case (colName, expr) =>
strToExpr(expr)(df(colName).expr)
})
}
/**
......
......@@ -87,6 +87,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
test("SPARK-17124 agg should be ordering preserving") {
val df = spark.range(2)
val ret = df.groupBy("id").agg("id" -> "sum", "id" -> "count", "id" -> "min")
assert(ret.schema.map(_.name) == Seq("id", "sum(id)", "count(id)", "min(id)"))
checkAnswer(
ret,
Row(0, 0, 1, 0) :: Row(1, 1, 1, 1) :: Nil
)
}
test("rollup") {
checkAnswer(
courseSales.rollup("course", "year").sum("earnings"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment