From 90d77e971f6b3fa268e411279f34bc1db4321991 Mon Sep 17 00:00:00 2001 From: zero323 <zero323@users.noreply.github.com> Date: Mon, 1 May 2017 21:39:17 -0700 Subject: [PATCH] [SPARK-20532][SPARKR] Implement grouping and grouping_id ## What changes were proposed in this pull request? Adds R wrappers for: - `o.a.s.sql.functions.grouping` as `grouping_bit` (to avoid shadowing `base::grouping`) - `o.a.s.sql.functions.grouping_id` ## How was this patch tested? Existing unit tests, additional unit tests. `check-cran.sh`. Author: zero323 <zero323@users.noreply.github.com> Closes #17807 from zero323/SPARK-20532. --- R/pkg/NAMESPACE | 2 + R/pkg/R/functions.R | 84 +++++++++++++++++++++++ R/pkg/R/generics.R | 8 +++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 56 ++++++++++++++- 4 files changed, 148 insertions(+), 2 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index e8de34d937..7ecd168137 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -249,6 +249,8 @@ exportMethods("%<=>%", "getField", "getItem", "greatest", + "grouping_bit", + "grouping_id", "hex", "histogram", "hour", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f9687d680e..38384a8991 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3890,3 +3890,87 @@ setMethod("not", jc <- callJStatic("org.apache.spark.sql.functions", "not", x@jc) column(jc) }) + +#' grouping_bit +#' +#' Indicates whether a specified column in a GROUP BY list is aggregated or not, +#' returns 1 for aggregated or 0 for not aggregated in the result set. +#' +#' Same as \code{GROUPING} in SQL and \code{grouping} function in Scala. 
+#' +#' @param x Column to compute on +#' +#' @rdname grouping_bit +#' @name grouping_bit +#' @family agg_funcs +#' @aliases grouping_bit,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # With cube +#' agg( +#' cube(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) +#' ) +#' +#' # With rollup +#' agg( +#' rollup(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) +#' ) +#' } +#' @note grouping_bit since 2.3.0 +setMethod("grouping_bit", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "grouping", x@jc) + column(jc) + }) + +#' grouping_id +#' +#' Returns the level of grouping. +#' +#' Equals to \code{ +#' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2) + ... + grouping_bit(cn) +#' } +#' +#' @param x Column to compute on +#' @param ... additional Column(s) (optional). +#' +#' @rdname grouping_id +#' @name grouping_id +#' @family agg_funcs +#' @aliases grouping_id,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # With cube +#' agg( +#' cube(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_id(df$cyl, df$gear, df$am) +#' ) +#' +#' # With rollup +#' agg( +#' rollup(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_id(df$cyl, df$gear, df$am) +#' ) +#' } +#' @note grouping_id since 2.3.0 +setMethod("grouping_id", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "grouping_id", jcols) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ef36765a7a..e02d46426a 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1052,6 +1052,14 @@ setGeneric("from_unixtime", function(x, ...) 
{ standardGeneric("from_unixtime") #' @export setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) +#' @rdname grouping_bit +#' @export +setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) + +#' @rdname grouping_id +#' @export +setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) + #' @rdname hex #' @export setGeneric("hex", function(x) { standardGeneric("hex") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 08296354ca..12867c15d1 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1848,7 +1848,11 @@ test_that("test multi-dimensional aggregations with cube and rollup", { orderBy( agg( cube(df, "year", "department"), - expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary") + expr("sum(salary) AS total_salary"), + expr("avg(salary) AS average_salary"), + alias(grouping_bit(df$year), "grouping_year"), + alias(grouping_bit(df$department), "grouping_department"), + alias(grouping_id(df$year, df$department), "grouping_id") ), "year", "department" ) @@ -1875,6 +1879,30 @@ test_that("test multi-dimensional aggregations with cube and rollup", { mean(c(21000, 32000, 22000)), # 2017 22000, 32000, 21000 # 2017 each department ), + grouping_year = c( + 1, # global + 1, 1, 1, # by department + 0, # 2016 + 0, 0, 0, # 2016 by department + 0, # 2017 + 0, 0, 0 # 2017 by department + ), + grouping_department = c( + 1, # global + 0, 0, 0, # by department + 1, # 2016 + 0, 0, 0, # 2016 by department + 1, # 2017 + 0, 0, 0 # 2017 by department + ), + grouping_id = c( + 3, # 11 + 2, 2, 2, # 10 + 1, # 01 + 0, 0, 0, # 00 + 1, # 01 + 0, 0, 0 # 00 + ), stringsAsFactors = FALSE ) @@ -1896,7 +1924,10 @@ test_that("test multi-dimensional aggregations with cube and rollup", { orderBy( agg( rollup(df, "year", "department"), - expr("sum(salary) AS total_salary"), expr("avg(salary) AS 
average_salary") + expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary"), + alias(grouping_bit(df$year), "grouping_year"), + alias(grouping_bit(df$department), "grouping_department"), + alias(grouping_id(df$year, df$department), "grouping_id") ), "year", "department" ) @@ -1920,6 +1951,27 @@ test_that("test multi-dimensional aggregations with cube and rollup", { mean(c(21000, 32000, 22000)), # 2017 22000, 32000, 21000 # 2017 each department ), + grouping_year = c( + 1, # global + 0, # 2016 + 0, 0, 0, # 2016 each department + 0, # 2017 + 0, 0, 0 # 2017 each department + ), + grouping_department = c( + 1, # global + 1, # 2016 + 0, 0, 0, # 2016 each department + 1, # 2017 + 0, 0, 0 # 2017 each department + ), + grouping_id = c( + 3, # 11 + 1, # 01 + 0, 0, 0, # 00 + 1, # 01 + 0, 0, 0 # 00 + ), stringsAsFactors = FALSE ) -- GitLab