diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 40f1f0f4429e03234cabea6da05cdcc6288cf572..75861d5de70927c055589a348230a42b2ca77fea 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2608,7 +2608,7 @@ setMethod("except",
 #' @param ... additional argument(s) passed to the method.
 #'
 #' @family SparkDataFrame functions
-#' @aliases write.df,SparkDataFrame,character-method
+#' @aliases write.df,SparkDataFrame-method
 #' @rdname write.df
 #' @name write.df
 #' @export
@@ -2622,21 +2622,31 @@ setMethod("except",
 #' }
 #' @note write.df since 1.4.0
 setMethod("write.df",
-          signature(df = "SparkDataFrame", path = "character"),
-          function(df, path, source = NULL, mode = "error", ...) {
+          signature(df = "SparkDataFrame"),
+          function(df, path = NULL, source = NULL, mode = "error", ...) {
+            if (!is.null(path) && !is.character(path)) {
+              stop("path should be character, NULL or omitted.")
+            }
+            if (!is.null(source) && !is.character(source)) {
+              stop("source should be character, NULL or omitted. It is the datasource specified ",
+                   "in 'spark.sql.sources.default' configuration by default.")
+            }
+            if (!is.character(mode)) {
+              stop("mode should be character or omitted. It is 'error' by default.")
+            }
             if (is.null(source)) {
               source <- getDefaultSqlSource()
             }
             jmode <- convertToJSaveMode(mode)
             options <- varargsToEnv(...)
             if (!is.null(path)) {
-              options[["path"]] <- path
+              options[["path"]] <- path
             }
             write <- callJMethod(df@sdf, "write")
             write <- callJMethod(write, "format", source)
             write <- callJMethod(write, "mode", jmode)
             write <- callJMethod(write, "options", options)
-            write <- callJMethod(write, "save", path)
+            write <- handledCallJMethod(write, "save")
           })

 #' @rdname write.df
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index ce531c3f888630ea3ff3d71f4433405eba7e72f2..baa87824beb91bdc018fe8393d54bb344d920889 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -771,6 +771,13 @@ dropTempView <- function(viewName) {
 #' @method read.df default
 #' @note read.df since 1.4.0
 read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) {
+  if (!is.null(path) && !is.character(path)) {
+    stop("path should be character, NULL or omitted.")
+  }
+  if (!is.null(source) && !is.character(source)) {
+    stop("source should be character, NULL or omitted. It is the datasource specified ",
+         "in 'spark.sql.sources.default' configuration by default.")
+  }
   sparkSession <- getSparkSession()
   options <- varargsToEnv(...)
   if (!is.null(path)) {
@@ -784,16 +791,16 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string
   }
   if (!is.null(schema)) {
     stopifnot(class(schema) == "structType")
-    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source,
-                       schema$jobj, options)
+    sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession,
+                              source, schema$jobj, options)
   } else {
-    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
-                       "loadDF", sparkSession, source, options)
+    sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession,
+                              source, options)
   }
   dataFrame(sdf)
 }

-read.df <- function(x, ...) {
+read.df <- function(x = NULL, ...) {
   dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", x, ...)
 }

@@ -805,7 +812,7 @@ loadDF.default <- function(path = NULL, source = NULL, schema = NULL, ...) {
   read.df(path, source, schema, ...)
 }

-loadDF <- function(x, ...) {
+loadDF <- function(x = NULL, ...) {
   dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...)
 }

diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 67a999da9bc26434db0c357be1c2d72c8b4f48fc..90a02e2778310015a46c7a08af1b74ce41c83c25 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -633,7 +633,7 @@ setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") })

 #' @rdname write.df
 #' @export
-setGeneric("write.df", function(df, path, source = NULL, mode = "error", ...) {
+setGeneric("write.df", function(df, path = NULL, source = NULL, mode = "error", ...) {
   standardGeneric("write.df")
 })

@@ -732,7 +732,7 @@ setGeneric("withColumnRenamed",

 #' @rdname write.df
 #' @export
-setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") })
+setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") })

 #' @rdname randomSplit
 #' @export
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 248c57532b6cf37000b7e279a58f9ebae86995ea..e69666453480c7a825e71e40cf0d1088536f9dc2 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -698,6 +698,58 @@ isSparkRShell <- function() {
   grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE)
 }

+# Works identically to `callJStatic(...)` but throws a prettily formatted exception.
+handledCallJStatic <- function(cls, method, ...) {
+  result <- tryCatch(callJStatic(cls, method, ...),
+                     error = function(e) {
+                       captureJVMException(e, method)
+                     })
+  result
+}
+
+# Works identically to `callJMethod(...)` but throws a prettily formatted exception.
+handledCallJMethod <- function(obj, method, ...) {
+  result <- tryCatch(callJMethod(obj, method, ...),
+                     error = function(e) {
+                       captureJVMException(e, method)
+                     })
+  result
+}
+
+captureJVMException <- function(e, method) {
+  rawmsg <- as.character(e)
+  if (any(grep("^Error in .*?: ", rawmsg))) {
+    # If the exception message starts with "Error in ...", this is possibly
+    # "Error in invokeJava(...)". Here, that prefix is replaced with
+    # `paste("Error in", method, ":")` in order to identify which function
+    # was called on the JVM side.
+    stacktrace <- strsplit(rawmsg, "Error in .*?: ")[[1]]
+    rmsg <- paste("Error in", method, ":")
+    stacktrace <- paste(rmsg[1], stacktrace[2])
+  } else {
+    # Otherwise, leave the error message as is, just in case.
+    stacktrace <- rawmsg
+  }
+
+  if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) {
+    msg <- strsplit(stacktrace, "java.lang.IllegalArgumentException: ", fixed = TRUE)[[1]]
+    # Extract the "Error in ..." message.
+    rmsg <- msg[1]
+    # Extract the first message of the JVM exception.
+    first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
+    stop(paste0(rmsg, "illegal argument - ", first), call. = FALSE)
+  } else if (any(grep("org.apache.spark.sql.AnalysisException: ", stacktrace))) {
+    msg <- strsplit(stacktrace, "org.apache.spark.sql.AnalysisException: ", fixed = TRUE)[[1]]
+    # Extract the "Error in ..." message.
+    rmsg <- msg[1]
+    # Extract the first message of the JVM exception.
+    first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
+    stop(paste0(rmsg, "analysis error - ", first), call. = FALSE)
+  } else {
+    stop(stacktrace, call. = FALSE)
+  }
+}
+
 # rbind a list of rows with raw (binary) columns
 #
 # @param inputData a list of rows, with each row a list
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 9d874a0988716e7dc0ea14d0b0b1cc8f3100c37c..f5ab601f274fe7a4294051d43cb4b11ae1f75146 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2544,6 +2544,41 @@ test_that("Spark version from SparkSession", {
   expect_equal(ver, version)
 })

+test_that("Call DataFrameWriter.save() API in Java without path and check argument types", {
+  df <- read.df(jsonPath, "json")
+  # This tests whether the exception is thrown from the JVM, not from the SparkR side.
+  # It makes sure that the path argument can be omitted in the write.df API, in which
+  # case DataFrameWriter.save() is called without a path.
+  expect_error(write.df(df, source = "csv"),
+               "Error in save : illegal argument - 'path' is not specified")
+
+  # Argument checking on the R side.
+  expect_error(write.df(df, "data.tmp", source = c(1, 2)),
+               paste("source should be character, NULL or omitted. It is the datasource specified",
+                     "in 'spark.sql.sources.default' configuration by default."))
+  expect_error(write.df(df, path = c(3)),
+               "path should be character, NULL or omitted.")
+  expect_error(write.df(df, mode = TRUE),
+               "mode should be character or omitted. It is 'error' by default.")
+})
+
+test_that("Call DataFrameReader.load() API in Java without path and check argument types", {
+  # This tests whether the exception is thrown from the JVM, not from the SparkR side.
+  # It makes sure that the path argument can be omitted in the read.df API, in which
+  # case DataFrameReader.load() is called without a path.
+  expect_error(read.df(source = "json"),
+               paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .",
+                     "It must be specified manually"))
+  expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist")
+
+  # Argument checking on the R side.
+  expect_error(read.df(path = c(3)),
+               "path should be character, NULL or omitted.")
+  expect_error(read.df(jsonPath, source = c(1, 2)),
+               paste("source should be character, NULL or omitted. It is the datasource specified",
+                     "in 'spark.sql.sources.default' configuration by default."))
+})
+
 unlink(parquetPath)
 unlink(orcPath)
 unlink(jsonPath)
diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R
index 77f25292f3f290fe0682dcd58cfa92f85c760dcc..69ed5549168b1834c54ce579dfb6154ec1b9c2db 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/inst/tests/testthat/test_utils.R
@@ -166,6 +166,16 @@ test_that("convertToJSaveMode", {
                'mode should be one of "append", "overwrite", "error", "ignore"') #nolint
 })

+test_that("captureJVMException", {
+  method <- "getSQLDataType"
+  expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method,
+                                    "unknown"),
+                        error = function(e) {
+                          captureJVMException(e, method)
+                        }),
+               "Error in getSQLDataType : illegal argument - Invalid type unknown")
+})
+
 test_that("hashCode", {
   expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA)
 })
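
Below is a minimal SparkR usage sketch of the read.df/write.df behavior changed by this patch. It is illustrative only: the JSON input path and the Parquet output directory are hypothetical and not part of the patch, and the error messages shown in comments assume the R-side validation and captureJVMException formatting added above.

library(SparkR)
sparkR.session()

# source falls back to 'spark.sql.sources.default' when omitted
people <- read.df("examples/src/main/resources/people.json", source = "json")

# path may now be omitted; in that case DataFrameWriter.save() is invoked without
# a path and a JVM-side failure surfaces as "Error in save : ..."
write.df(people, path = "/tmp/people_parquet", source = "parquet", mode = "overwrite")

# R-side argument checks added by this patch (calls shown as comments):
# write.df(people, path = c(3))  # "path should be character, NULL or omitted."
# write.df(people, mode = TRUE)  # "mode should be character or omitted. It is 'error' by default."
# read.df(source = "json")       # JVM analysis error surfaced via handledCallJStatic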