diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index b6b559adf06ea0442a4ef25b439ba6335bf6c3cc..e804e30e14b86076983bbcf5ce03fa5516df6213 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -203,6 +203,8 @@ exportMethods("%in%", "cbrt", "ceil", "ceiling", + "collect_list", + "collect_set", "column", "concat", "concat_ws", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f854df11e576902e35cf959acca6429fa8ed49b4..e7decb91867bd59db5d6fbbbe6de5349c5765347 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3705,3 +3705,43 @@ setMethod("create_map", jc <- callJStatic("org.apache.spark.sql.functions", "map", jcols) column(jc) }) + +#' collect_list +#' +#' Creates a list of objects with duplicates. +#' +#' @param x Column to compute on +#' +#' @rdname collect_list +#' @name collect_list +#' @family agg_funcs +#' @aliases collect_list,Column-method +#' @export +#' @examples \dontrun{collect_list(df$x)} +#' @note collect_list since 2.3.0 +setMethod("collect_list", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "collect_list", x@jc) + column(jc) + }) + +#' collect_set +#' +#' Creates a list of objects with duplicate elements eliminated. +#' +#' @param x Column to compute on +#' +#' @rdname collect_set +#' @name collect_set +#' @family agg_funcs +#' @aliases collect_set,Column-method +#' @export +#' @examples \dontrun{collect_set(df$x)} +#' @note collect_set since 2.3.0 +setMethod("collect_set", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "collect_set", x@jc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index da46823f52a17d10954b1f0ce668c061ae97db2f..61d248ebd2e3e3741cc16cddd98e439df3e3c40d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -918,6 +918,14 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @export setGeneric("ceil", function(x) { standardGeneric("ceil") }) +#' @rdname collect_list +#' @export +setGeneric("collect_list", function(x) { standardGeneric("collect_list") }) + +#' @rdname collect_set +#' @export +setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) + #' @rdname column #' @export setGeneric("column", function(x) { standardGeneric("column") }) @@ -1358,6 +1366,7 @@ setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @export setGeneric("year", function(x) { standardGeneric("year") }) + ###################### Spark.ML Methods ########################## #' @rdname fitted diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9e87a47106994e7cbc1bb28a682547c68c6a1fcf..bf2093fdc475af567c4f0783ca24d3a542a94bee 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1731,6 +1731,28 @@ test_that("group by, agg functions", { expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) + # Test collect_list and collect_set + gd3_collections_local <- collect( + agg(gd3, collect_set(df8$age), collect_list(df8$age)) + ) + + expect_equal( + unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 2]), + c(30) + ) + + expect_equal( + unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 3]), + c(30, 30) + ) + + expect_equal( + sort(unlist( + gd3_collections_local[gd3_collections_local$name == "Justin", 3] + )), + c(1, 19) + ) + unlink(jsonPath2) unlink(jsonPath3) })