From 90d77e971f6b3fa268e411279f34bc1db4321991 Mon Sep 17 00:00:00 2001
From: zero323 <zero323@users.noreply.github.com>
Date: Mon, 1 May 2017 21:39:17 -0700
Subject: [PATCH] [SPARK-20532][SPARKR] Implement grouping and grouping_id

## What changes were proposed in this pull request?

Adds R wrappers for:

- `o.a.s.sql.functions.grouping` as `o.a.s.sql.functions.is_grouping` (to avoid shading `base::grouping`
- `o.a.s.sql.functions.grouping_id`

## How was this patch tested?

Existing unit tests, additional unit tests. `check-cran.sh`.

Author: zero323 <zero323@users.noreply.github.com>

Closes #17807 from zero323/SPARK-20532.
---
 R/pkg/NAMESPACE                           |  2 +
 R/pkg/R/functions.R                       | 84 +++++++++++++++++++++++
 R/pkg/R/generics.R                        |  8 +++
 R/pkg/inst/tests/testthat/test_sparkSQL.R | 56 ++++++++++++++-
 4 files changed, 148 insertions(+), 2 deletions(-)

diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index e8de34d937..7ecd168137 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -249,6 +249,8 @@ exportMethods("%<=>%",
               "getField",
               "getItem",
               "greatest",
+              "grouping_bit",
+              "grouping_id",
               "hex",
               "histogram",
               "hour",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index f9687d680e..38384a8991 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -3890,3 +3890,87 @@ setMethod("not",
             jc <- callJStatic("org.apache.spark.sql.functions", "not", x@jc)
             column(jc)
           })
+
+#' grouping_bit
+#'
+#' Indicates whether a specified column in a GROUP BY list is aggregated or not,
+#' returns 1 for aggregated or 0 for not aggregated in the result set.
+#'
+#' Same as \code{GROUPING} in SQL and \code{grouping} function in Scala.
+#'
+#' @param x Column to compute on
+#'
+#' @rdname grouping_bit
+#' @name grouping_bit
+#' @family agg_funcs
+#' @aliases grouping_bit,Column-method
+#' @export
+#' @examples \dontrun{
+#' df <- createDataFrame(mtcars)
+#'
+#' # With cube
+#' agg(
+#'   cube(df, "cyl", "gear", "am"),
+#'   mean(df$mpg),
+#'   grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am)
+#' )
+#'
+#' # With rollup
+#' agg(
+#'   rollup(df, "cyl", "gear", "am"),
+#'   mean(df$mpg),
+#'   grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am)
+#' )
+#' }
+#' @note grouping_bit since 2.3.0
+setMethod("grouping_bit",
+          signature(x = "Column"),
+          function(x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "grouping", x@jc)
+            column(jc)
+          })
+
+#' grouping_id
+#'
+#' Returns the level of grouping.
+#'
+#' Equals to \code{
+#' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2)  + ... + grouping_bit(cn)
+#' }
+#'
+#' @param x Column to compute on
+#' @param ... additional Column(s) (optional).
+#'
+#' @rdname grouping_id
+#' @name grouping_id
+#' @family agg_funcs
+#' @aliases grouping_id,Column-method
+#' @export
+#' @examples \dontrun{
+#' df <- createDataFrame(mtcars)
+#'
+#' # With cube
+#' agg(
+#'   cube(df, "cyl", "gear", "am"),
+#'   mean(df$mpg),
+#'   grouping_id(df$cyl, df$gear, df$am)
+#' )
+#'
+#' # With rollup
+#' agg(
+#'   rollup(df, "cyl", "gear", "am"),
+#'   mean(df$mpg),
+#'   grouping_id(df$cyl, df$gear, df$am)
+#' )
+#' }
+#' @note grouping_id since 2.3.0
+setMethod("grouping_id",
+          signature(x = "Column"),
+          function(x, ...) {
+            jcols <- lapply(list(x, ...), function (x) {
+              stopifnot(class(x) == "Column")
+              x@jc
+            })
+            jc <- callJStatic("org.apache.spark.sql.functions", "grouping_id", jcols)
+            column(jc)
+          })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index ef36765a7a..e02d46426a 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1052,6 +1052,14 @@ setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime")
 #' @export
 setGeneric("greatest", function(x, ...) { standardGeneric("greatest") })
 
+#' @rdname grouping_bit
+#' @export
+setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") })
+
+#' @rdname grouping_id
+#' @export
+setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") })
+
 #' @rdname hex
 #' @export
 setGeneric("hex", function(x) { standardGeneric("hex") })
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 08296354ca..12867c15d1 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1848,7 +1848,11 @@ test_that("test multi-dimensional aggregations with cube and rollup", {
     orderBy(
       agg(
         cube(df, "year", "department"),
-        expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary")
+        expr("sum(salary) AS total_salary"),
+        expr("avg(salary) AS average_salary"),
+        alias(grouping_bit(df$year), "grouping_year"),
+        alias(grouping_bit(df$department), "grouping_department"),
+        alias(grouping_id(df$year, df$department), "grouping_id")
       ),
       "year", "department"
     )
@@ -1875,6 +1879,30 @@ test_that("test multi-dimensional aggregations with cube and rollup", {
       mean(c(21000, 32000, 22000)), # 2017
       22000, 32000, 21000 # 2017 each department
     ),
+    grouping_year = c(
+      1, # global
+      1, 1, 1, # by department
+      0, # 2016
+      0, 0, 0, # 2016 by department
+      0, # 2017
+      0, 0, 0 # 2017 by department
+    ),
+    grouping_department = c(
+      1, # global
+      0, 0, 0, # by department
+      1, # 2016
+      0, 0, 0, # 2016 by department
+      1, # 2017
+      0, 0, 0 # 2017 by department
+    ),
+    grouping_id = c(
+      3, #  11
+      2, 2, 2, # 10
+      1, # 01
+      0, 0, 0, # 00
+      1, # 01
+      0, 0, 0 # 00
+    ),
     stringsAsFactors = FALSE
   )
 
@@ -1896,7 +1924,10 @@ test_that("test multi-dimensional aggregations with cube and rollup", {
     orderBy(
       agg(
         rollup(df, "year", "department"),
-        expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary")
+        expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary"),
+        alias(grouping_bit(df$year), "grouping_year"),
+        alias(grouping_bit(df$department), "grouping_department"),
+        alias(grouping_id(df$year, df$department), "grouping_id")
       ),
       "year", "department"
     )
@@ -1920,6 +1951,27 @@ test_that("test multi-dimensional aggregations with cube and rollup", {
       mean(c(21000, 32000, 22000)), # 2017
       22000, 32000, 21000 # 2017 each department
     ),
+    grouping_year = c(
+      1, # global
+      0, # 2016
+      0, 0, 0, # 2016 each department
+      0, # 2017
+      0, 0, 0 # 2017 each department
+    ),
+    grouping_department = c(
+      1, # global
+      1, # 2016
+      0, 0, 0, # 2016 each department
+      1, # 2017
+      0, 0, 0 # 2017 each department
+    ),
+    grouping_id = c(
+      3, # 11
+      1, # 01
+      0, 0, 0, # 00
+      1, # 01
+      0, 0, 0 # 00
+    ),
     stringsAsFactors = FALSE
   )
 
-- 
GitLab