diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index e5521f3cffadf0384bf4a4a6ba5dc6b22c3f060d..d9c10b4a4b9fb448f5dd11213debee1e00436e4b 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -536,15 +536,27 @@ setMethod("factorial", #' #' Aggregate function: returns the first value in a group. #' +#' The function by default returns the first values it sees. It will return the first non-missing +#' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' #' @rdname first #' @name first #' @family agg_funcs #' @export -#' @examples \dontrun{first(df$c)} +#' @examples +#' \dontrun{ +#' first(df$c) +#' first(df$c, TRUE) +#' } setMethod("first", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "first", x@jc) + signature(x = "characterOrColumn"), + function(x, na.rm = FALSE) { + col <- if (class(x) == "Column") { + x@jc + } else { + x + } + jc <- callJStatic("org.apache.spark.sql.functions", "first", col, na.rm) column(jc) }) @@ -663,15 +675,27 @@ setMethod("kurtosis", #' #' Aggregate function: returns the last value in a group. #' +#' The function by default returns the last values it sees. It will return the last non-missing +#' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' #' @rdname last #' @name last #' @family agg_funcs #' @export -#' @examples \dontrun{last(df$c)} +#' @examples +#' \dontrun{ +#' last(df$c) +#' last(df$c, TRUE) +#' } setMethod("last", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "last", x@jc) + signature(x = "characterOrColumn"), + function(x, na.rm = FALSE) { + col <- if (class(x) == "Column") { + x@jc + } else { + x + } + jc <- callJStatic("org.apache.spark.sql.functions", "last", col, na.rm) column(jc) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 3db72b57954d76bc57dc644f3f75a095da79bc83..ddfa61717af2e4f44ce69ef7bd8047c2695ad9ad 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -84,7 +84,7 @@ setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) # @rdname first # @export -setGeneric("first", function(x) { standardGeneric("first") }) +setGeneric("first", function(x, ...) { standardGeneric("first") }) # @rdname flatMap # @export @@ -889,7 +889,7 @@ setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last #' @export -setGeneric("last", function(x) { standardGeneric("last") }) +setGeneric("last", function(x, ...) { standardGeneric("last") }) #' @rdname last_day #' @export diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index cad5766812aed86d9df45accb0312e1789c77640..11a8f12fd54321295a7e5a74ef8be3e35d1c3752 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1076,6 +1076,17 @@ test_that("column functions", { result <- collect(select(df, encode(df$a, "utf-8"), decode(df$c, "utf-8"))) expect_equal(result[[1]][[1]], bytes) expect_equal(result[[2]], markUtf8("大åƒä¸–界")) + + # Test first(), last() + df <- read.json(sqlContext, jsonPath) + expect_equal(collect(select(df, first(df$age)))[[1]], NA) + expect_equal(collect(select(df, first(df$age, TRUE)))[[1]], 30) + expect_equal(collect(select(df, first("age")))[[1]], NA) + expect_equal(collect(select(df, first("age", TRUE)))[[1]], 30) + expect_equal(collect(select(df, last(df$age)))[[1]], 19) + expect_equal(collect(select(df, last(df$age, TRUE)))[[1]], 19) + expect_equal(collect(select(df, last("age")))[[1]], 19) + expect_equal(collect(select(df, last("age", TRUE)))[[1]], 19) }) test_that("column binary mathfunctions", {