diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 7bb8ef2595b59afa36d42bf8448b4993ab57cd60..356bcee3cf5c625ce64b4d5ddbb75a2d5772c063 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -215,7 +215,7 @@ setMethod("%in%", #' otherwise #' -#' If values in the specified column are null, returns the value. +#' If values in the specified column are null, returns the value. #' Can be used in conjunction with `when` to specify a default value for expressions. #' #' @rdname otherwise @@ -225,7 +225,7 @@ setMethod("%in%", setMethod("otherwise", signature(x = "Column", value = "ANY"), function(x, value) { - value <- ifelse(class(value) == "Column", value@jc, value) + value <- if (class(value) == "Column") { value@jc } else { value } jc <- callJMethod(x@jc, "otherwise", value) column(jc) }) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 09e4e04335a33c639089d116dfb0033c5853816d..df36bc869acb4df8b5c892bbf0074874acf6cc41 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -37,7 +37,7 @@ setMethod("lit", signature("ANY"), function(x) { jc <- callJStatic("org.apache.spark.sql.functions", "lit", - ifelse(class(x) == "Column", x@jc, x)) + if (class(x) == "Column") { x@jc } else { x }) column(jc) }) @@ -2262,7 +2262,7 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), setMethod("when", signature(condition = "Column", value = "ANY"), function(condition, value) { condition <- condition@jc - value <- ifelse(class(value) == "Column", value@jc, value) + value <- if (class(value) == "Column") { value@jc } else { value } jc <- callJStatic("org.apache.spark.sql.functions", "when", condition, value) column(jc) }) @@ -2277,13 +2277,16 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @name ifelse #' @seealso \link{when} #' @export -#' @examples \dontrun{ifelse(df$a > 1 & df$b > 2, 0, 1)} +#' @examples \dontrun{ +#' ifelse(df$a > 1 & df$b > 2, 0, 1) +#' ifelse(df$a > 1, df$a, 1) +#' } setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), function(test, yes, no) { test <- test@jc - yes <- ifelse(class(yes) == "Column", yes@jc, yes) - no <- ifelse(class(no) == "Column", no@jc, no) + yes <- if (class(yes) == "Column") { yes@jc } else { yes } + no <- if (class(no) == "Column") { no@jc } else { no } jc <- callJMethod(callJStatic("org.apache.spark.sql.functions", "when", test, yes), diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 135c7576e5291fdaa87ed7c26a62f5f6b0208f07..c2b6adbe3ae0126ee88756fcf8a7b0178d107f1c 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1120,6 +1120,14 @@ test_that("when(), otherwise() and ifelse() on a DataFrame", { expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0)) }) +test_that("when(), otherwise() and ifelse() with column on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, lit(1))))[, 1], c(NA, 1)) + expect_equal(collect(select(df, otherwise(when(df$a > 1, lit(1)), lit(0))))[, 1], c(0, 1)) + expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, lit(0), lit(1))))[, 1], c(1, 0)) +}) + test_that("group by, agg functions", { df <- read.json(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum")