From 390b22fad69a33eb6daee25b6b858a2e768670a5 Mon Sep 17 00:00:00 2001
From: Sun Rui <rui.sun@intel.com>
Date: Tue, 13 Oct 2015 22:31:23 -0700
Subject: [PATCH] [SPARK-10996] [SPARKR] Implement sampleBy() in
 DataFrameStatFunctions.

Author: Sun Rui <rui.sun@intel.com>

Closes #9023 from sun-rui/SPARK-10996.
---
 R/pkg/NAMESPACE                  |  3 ++-
 R/pkg/R/DataFrame.R              | 14 ++++++--------
 R/pkg/R/generics.R               |  6 +++++-
 R/pkg/R/sparkR.R                 | 12 +++---------
 R/pkg/R/stats.R                  | 32 ++++++++++++++++++++++++++++++++
 R/pkg/R/utils.R                  | 18 ++++++++++++++++++
 R/pkg/inst/tests/test_sparkSQL.R | 10 ++++++++++
 7 files changed, 76 insertions(+), 19 deletions(-)

diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index ed9cd94e03..52f7a0106a 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -65,6 +65,7 @@ exportMethods("arrange",
               "repartition",
               "sample",
               "sample_frac",
+              "sampleBy",
               "saveAsParquetFile",
               "saveAsTable",
               "saveDF",
@@ -254,4 +255,4 @@ export("structField",
        "structType.structField",
        "print.structType")
 
-export("as.data.frame")
\ No newline at end of file
+export("as.data.frame")
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index b7f5f978eb..993be82a47 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1831,17 +1831,15 @@ setMethod("fillna",
             if (length(colNames) == 0 || !all(colNames != "")) {
               stop("value should be an a named list with each name being a column name.")
             }
-
-            # Convert to the named list to an environment to be passed to JVM
-            valueMap <- new.env()
-            for (col in colNames) {
-              # Check each item in the named list is of valid type
-              v <- value[[col]]
+            # Check each item in the named list is of valid type
+            lapply(value, function(v) {
               if (!(class(v) %in% c("integer", "numeric", "character"))) {
                 stop("Each item in value should be an integer, numeric or charactor.")
               }
-              valueMap[[col]] <- v
-            }
+            })
+
+            # Convert to the named list to an environment to be passed to JVM
+            valueMap <- convertNamedListToEnv(value)
 
             # When value is a named list, caller is expected not to pass in cols
             if (!is.null(cols)) {
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index c106a00245..4a419f785e 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -509,6 +509,10 @@ setGeneric("sample",
 setGeneric("sample_frac",
            function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") })
 
+#' @rdname statfunctions
+#' @export
+setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") })
+
 #' @rdname saveAsParquetFile
 #' @export
 setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") })
@@ -1006,4 +1010,4 @@ setGeneric("as.data.frame")
 
 #' @rdname attach
 #' @export
-setGeneric("attach")
\ No newline at end of file
+setGeneric("attach")
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index cc47110f54..9cf2f1a361 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -163,19 +163,13 @@ sparkR.init <- function(
     sparkHome <- suppressWarnings(normalizePath(sparkHome))
   }
 
-  sparkEnvirMap <- new.env()
-  for (varname in names(sparkEnvir)) {
-    sparkEnvirMap[[varname]] <- sparkEnvir[[varname]]
-  }
+  sparkEnvirMap <- convertNamedListToEnv(sparkEnvir)
 
-  sparkExecutorEnvMap <- new.env()
-  if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) {
+  sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv)
+  if(is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) {
     sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <-
       paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH"))
   }
-  for (varname in names(sparkExecutorEnv)) {
-    sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]]
-  }
 
   nonEmptyJars <- Filter(function(x) { x != "" }, jars)
   localJarPaths <- lapply(nonEmptyJars,
diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R
index 4928cf4d43..f79329b115 100644
--- a/R/pkg/R/stats.R
+++ b/R/pkg/R/stats.R
@@ -127,3 +127,35 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"),
             sct <- callJMethod(statFunctions, "freqItems", as.list(cols), support)
             collect(dataFrame(sct))
           })
+
+#' sampleBy
+#'
+#' Returns a stratified sample without replacement based on the fraction given on each stratum.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param col column that defines strata
+#' @param fractions A named list giving sampling fraction for each stratum. If a stratum is
+#'                  not specified, we treat its fraction as zero.
+#' @param seed random seed
+#' @return A new DataFrame that represents the stratified sample
+#'
+#' @rdname statfunctions
+#' @name sampleBy
+#' @export
+#' @examples
+#'\dontrun{
+#' df <- jsonFile(sqlContext, "/path/to/file.json")
+#' sample <- sampleBy(df, "key", fractions, 36)
+#' }
+setMethod("sampleBy",
+          signature(x = "DataFrame", col = "character",
+                    fractions = "list", seed = "numeric"),
+          function(x, col, fractions, seed) {
+            fractionsEnv <- convertNamedListToEnv(fractions)
+
+            statFunctions <- callJMethod(x@sdf, "stat")
+            # Seed is expected to be Long on Scala side, here convert it to an integer
+            # due to SerDe limitation now.
+            sdf <- callJMethod(statFunctions, "sampleBy", col, fractionsEnv, as.integer(seed))
+            dataFrame(sdf)
+          })
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 94f16c7ac5..0b9e2957fe 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -605,3 +605,21 @@ structToList <- function(struct) {
   class(struct) <- "list"
   struct
 }
+
+# Convert a named list to an environment to be passed to JVM
+convertNamedListToEnv <- function(namedList) {
+  # Make sure each item in the list has a name
+  names <- names(namedList)
+  stopifnot(
+    if (is.null(names)) {
+      length(namedList) == 0
+    } else {
+      !any(is.na(names))
+    })
+
+  env <- new.env()
+  for (name in names) {
+    env[[name]] <- namedList[[name]]
+  }
+  env
+}
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 46cab7646d..e1b42b0804 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -1416,6 +1416,16 @@ test_that("freqItems() on a DataFrame", {
   expect_identical(result[[2]], list(list(-1, -99)))
 })
 
+test_that("sampleBy() on a DataFrame", {
+  l <- lapply(c(0:99), function(i) { as.character(i %% 3) })
+  df <- createDataFrame(sqlContext, l, "key")
+  fractions <- list("0" = 0.1, "1" = 0.2)
+  sample <- sampleBy(df, "key", fractions, 0)
+  result <- collect(orderBy(count(groupBy(sample, "key")), "key"))
+  expect_identical(as.list(result[1, ]), list(key = "0", count = 2))
+  expect_identical(as.list(result[2, ]), list(key = "1", count = 10))
+})
+
 test_that("SQL error message is returned from JVM", {
   retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
   expect_equal(grepl("Table Not Found: blah", retError), TRUE)
-- 
GitLab
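
For context (not part of the patch): a minimal SparkR session sketch of how the new sampleBy() API is exercised, mirroring the added test in test_sparkSQL.R. It assumes a 1.5-era SparkR setup where sc and sqlContext have already been created via sparkR.init() and sparkRSQL.init(); names and values here are illustrative only.

# Sketch only; assumes an existing sqlContext from sparkRSQL.init().
l <- lapply(c(0:99), function(i) { as.character(i %% 3) })   # keys "0", "1", "2"
df <- createDataFrame(sqlContext, l, "key")
fractions <- list("0" = 0.1, "1" = 0.2)        # strata not listed are sampled with fraction 0
sampled <- sampleBy(df, "key", fractions, 36)  # seed is cast to integer before the JVM call
showDF(count(groupBy(sampled, "key")))         # rough per-stratum counts of the sample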