diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ba29614e7b17942975b3440a97d15063b58e114e..64ffdcffc9cafdc1f4b4756dfde0a52d4d2a7949 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -59,33 +59,56 @@ exportMethods("arrange", exportClasses("Column") exportMethods("abs", + "acos", "alias", "approxCountDistinct", "asc", + "asin", + "atan", + "atan2", "avg", "cast", + "cbrt", + "ceiling", "contains", + "cos", + "cosh", "countDistinct", "desc", "endsWith", + "exp", + "expm1", + "floor", "getField", "getItem", + "hypot", "isNotNull", "isNull", "last", "like", + "log", + "log10", + "log1p", "lower", "max", "mean", "min", "n", "n_distinct", + "rint", "rlike", + "sign", + "sin", + "sinh", "sqrt", "startsWith", "substr", "sum", "sumDistinct", + "tan", + "tanh", + "toDegrees", + "toRadians", "upper") exportClasses("GroupedData") diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 9a68445ab451abfcde8d542cc0bbda48f4ed2360..80e92d3105a36d23c34909abaf2b106b91fc2746 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -55,12 +55,17 @@ operators <- list( "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", # we can not override `&&` and `||`, so use `&` and `|` instead - "&" = "and", "|" = "or" #, "!" = "unary_$bang" + "&" = "and", "|" = "or", #, "!" = "unary_$bang" + "^" = "pow" ) column_functions1 <- c("asc", "desc", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", - "first", "last", "lower", "upper", "sumDistinct") + "first", "last", "lower", "upper", "sumDistinct", + "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", + "expm1", "floor", "log", "log10", "log1p", "rint", "sign", + "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians") +binary_mathfunctions<- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, @@ -76,7 +81,11 @@ createOperator <- function(op) { if (class(e2) == "Column") { e2 <- e2@jc } - callJMethod(e1@jc, operators[[op]], e2) + if (op == "^") { + jc <- callJStatic("org.apache.spark.sql.functions", operators[[op]], e1@jc, e2) + } else { + callJMethod(e1@jc, operators[[op]], e2) + } } column(jc) }) @@ -106,11 +115,29 @@ createStaticFunction <- function(name) { setMethod(name, signature(x = "Column"), function(x) { + if (name == "ceiling") { + name <- "ceil" + } + if (name == "sign") { + name <- "signum" + } jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) column(jc) }) } +createBinaryMathfunctions <- function(name) { + setMethod(name, + signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) + column(jc) + }) +} + createMethods <- function() { for (op in names(operators)) { createOperator(op) @@ -124,6 +151,9 @@ createMethods <- function() { for (x in functions) { createStaticFunction(x) } + for (name in binary_mathfunctions) { + createBinaryMathfunctions(name) + } } createMethods() diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 6d2bfb1181e5a4aaf5b9adb846340f030b3b51e8..a23d3b217b2fd7457293e031f3d2891c72c0dadd 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -552,6 +552,10 @@ setGeneric("avg", function(x, ...) { standardGeneric("avg") }) #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) +#' @rdname column +#' @export +setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) + #' @rdname column #' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) @@ -575,6 +579,10 @@ setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @export setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) +#' @rdname column +#' @export +setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) + #' @rdname column #' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) @@ -603,6 +611,10 @@ setGeneric("n", function(x) { standardGeneric("n") }) #' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) +#' @rdname column +#' @export +setGeneric("rint", function(x, ...) { standardGeneric("rint") }) + #' @rdname column #' @export setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) @@ -615,6 +627,14 @@ setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) #' @export setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) +#' @rdname column +#' @export +setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) + +#' @rdname column +#' @export +setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) + #' @rdname column #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 1109e8fdba3fdc23b1c441ebab0bd1bec0fa02be..3e5658eb5b24bf9c2c23926fcffe212655935bfd 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -530,6 +530,7 @@ test_that("column operators", { c2 <- (- c + 1 - 2) * 3 / 4.0 c3 <- (c + c2 - c2) * c2 %% c2 c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) + c5 <- c2 ^ c3 ^ c4 }) test_that("column functions", { @@ -538,6 +539,29 @@ test_that("column functions", { c3 <- lower(c) + upper(c) + first(c) + last(c) c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") c5 <- n(c) + n_distinct(c) + c5 <- acos(c) + asin(c) + atan(c) + cbrt(c) + c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c) + c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) + c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) + c9 <- toDegrees(c) + toRadians(c) +}) + +test_that("column binary mathfunctions", { + lines <- c("{\"a\":1, \"b\":5}", + "{\"a\":2, \"b\":6}", + "{\"a\":3, \"b\":7}", + "{\"a\":4, \"b\":8}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + df <- jsonFile(sqlCtx, jsonPathWithDup) + expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) }) test_that("string operators", {