Skip to content
Snippets Groups Projects
Commit 390b22fa authored by Sun Rui's avatar Sun Rui Committed by Shivaram Venkataraman
Browse files

[SPARK-10996] [SPARKR] Implement sampleBy() in DataFrameStatFunctions.

Author: Sun Rui <rui.sun@intel.com>

Closes #9023 from sun-rui/SPARK-10996.
parent 8b328857
No related branches found
No related tags found
No related merge requests found
...@@ -65,6 +65,7 @@ exportMethods("arrange", ...@@ -65,6 +65,7 @@ exportMethods("arrange",
"repartition", "repartition",
"sample", "sample",
"sample_frac", "sample_frac",
"sampleBy",
"saveAsParquetFile", "saveAsParquetFile",
"saveAsTable", "saveAsTable",
"saveDF", "saveDF",
...@@ -254,4 +255,4 @@ export("structField", ...@@ -254,4 +255,4 @@ export("structField",
"structType.structField", "structType.structField",
"print.structType") "print.structType")
export("as.data.frame") export("as.data.frame")
\ No newline at end of file
...@@ -1831,17 +1831,15 @@ setMethod("fillna", ...@@ -1831,17 +1831,15 @@ setMethod("fillna",
if (length(colNames) == 0 || !all(colNames != "")) { if (length(colNames) == 0 || !all(colNames != "")) {
stop("value should be an a named list with each name being a column name.") stop("value should be an a named list with each name being a column name.")
} }
# Check each item in the named list is of valid type
# Convert to the named list to an environment to be passed to JVM lapply(value, function(v) {
valueMap <- new.env()
for (col in colNames) {
# Check each item in the named list is of valid type
v <- value[[col]]
if (!(class(v) %in% c("integer", "numeric", "character"))) { if (!(class(v) %in% c("integer", "numeric", "character"))) {
stop("Each item in value should be an integer, numeric or charactor.") stop("Each item in value should be an integer, numeric or charactor.")
} }
valueMap[[col]] <- v })
}
# Convert the named list to an environment to be passed to JVM
valueMap <- convertNamedListToEnv(value)
# When value is a named list, caller is expected not to pass in cols # When value is a named list, caller is expected not to pass in cols
if (!is.null(cols)) { if (!is.null(cols)) {
......
...@@ -509,6 +509,10 @@ setGeneric("sample", ...@@ -509,6 +509,10 @@ setGeneric("sample",
setGeneric("sample_frac", setGeneric("sample_frac",
function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") })
#' @rdname statfunctions
#' @export
setGeneric("sampleBy",
           function(x, col, fractions, seed) {
             # Dispatch to the method registered for the argument classes
             standardGeneric("sampleBy")
           })
#' @rdname saveAsParquetFile #' @rdname saveAsParquetFile
#' @export #' @export
setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") })
...@@ -1006,4 +1010,4 @@ setGeneric("as.data.frame") ...@@ -1006,4 +1010,4 @@ setGeneric("as.data.frame")
#' @rdname attach #' @rdname attach
#' @export #' @export
setGeneric("attach") setGeneric("attach")
\ No newline at end of file
...@@ -163,19 +163,13 @@ sparkR.init <- function( ...@@ -163,19 +163,13 @@ sparkR.init <- function(
sparkHome <- suppressWarnings(normalizePath(sparkHome)) sparkHome <- suppressWarnings(normalizePath(sparkHome))
} }
sparkEnvirMap <- new.env() sparkEnvirMap <- convertNamedListToEnv(sparkEnvir)
for (varname in names(sparkEnvir)) {
sparkEnvirMap[[varname]] <- sparkEnvir[[varname]]
}
sparkExecutorEnvMap <- new.env() sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv)
if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { if(is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) {
sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <-
paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH"))
} }
for (varname in names(sparkExecutorEnv)) {
sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]]
}
nonEmptyJars <- Filter(function(x) { x != "" }, jars) nonEmptyJars <- Filter(function(x) { x != "" }, jars)
localJarPaths <- lapply(nonEmptyJars, localJarPaths <- lapply(nonEmptyJars,
......
...@@ -127,3 +127,35 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"), ...@@ -127,3 +127,35 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"),
sct <- callJMethod(statFunctions, "freqItems", as.list(cols), support) sct <- callJMethod(statFunctions, "freqItems", as.list(cols), support)
collect(dataFrame(sct)) collect(dataFrame(sct))
}) })
#' sampleBy
#'
#' Returns a stratified sample without replacement based on the fraction given on each stratum.
#'
#' @param x A SparkSQL DataFrame
#' @param col name of the column that defines the strata
#' @param fractions A named list giving the sampling fraction for each stratum; each name
#'                  identifies a stratum. If a stratum is not specified, we treat its
#'                  fraction as zero.
#' @param seed random seed
#' @return A new DataFrame that represents the stratified sample
#'
#' @rdname statfunctions
#' @name sampleBy
#' @export
#' @examples
#'\dontrun{
#' df <- jsonFile(sqlContext, "/path/to/file.json")
#' fractions <- list("0" = 0.1, "1" = 0.2)
#' sample <- sampleBy(df, "key", fractions, 36)
#' }
setMethod("sampleBy",
          signature(x = "DataFrame", col = "character",
                    fractions = "list", seed = "numeric"),
          function(x, col, fractions, seed) {
            # The JVM side receives the fractions as an environment built from the named list
            strataFractions <- convertNamedListToEnv(fractions)
            stat <- callJMethod(x@sdf, "stat")
            # Seed is expected to be Long on Scala side, here convert it to an integer
            # due to SerDe limitation now.
            intSeed <- as.integer(seed)
            sampledSdf <- callJMethod(stat, "sampleBy", col, strataFractions, intSeed)
            dataFrame(sampledSdf)
          })
...@@ -605,3 +605,21 @@ structToList <- function(struct) { ...@@ -605,3 +605,21 @@ structToList <- function(struct) {
class(struct) <- "list" class(struct) <- "list"
struct struct
} }
# Convert a named list to an environment to be passed to JVM.
#
# Every item of the list must carry a usable name (neither NA nor the empty
# string). A list of length zero is accepted and yields an empty environment.
#
# @param namedList a list in which every item is named
# @return a new environment with one binding per name found in namedList
convertNamedListToEnv <- function(namedList) {
  # Make sure each item in the list has a name
  keys <- names(namedList)
  if (is.null(keys)) {
    # No names attribute at all: only valid when there is nothing to convert
    stopifnot(length(namedList) == 0)
  } else {
    # A partially named list reports "" (not NA) for its unnamed items, so both
    # forms are rejected here rather than failing later during assignment.
    # nzchar(NA) is TRUE, hence the separate is.na() check.
    stopifnot(!any(is.na(keys)), all(nzchar(keys)))
  }
  env <- new.env()
  for (key in keys) {
    env[[key]] <- namedList[[key]]
  }
  env
}
...@@ -1416,6 +1416,16 @@ test_that("freqItems() on a DataFrame", { ...@@ -1416,6 +1416,16 @@ test_that("freqItems() on a DataFrame", {
expect_identical(result[[2]], list(list(-1, -99))) expect_identical(result[[2]], list(list(-1, -99)))
}) })
test_that("sampleBy() on a DataFrame", {
  # 100 rows whose single "key" column cycles through "0", "1", "2"
  keys <- lapply(0:99, function(i) { as.character(i %% 3) })
  df <- createDataFrame(sqlContext, keys, "key")
  # Stratum "2" is not listed in the fractions
  strataFractions <- list("0" = 0.1, "1" = 0.2)
  sampled <- sampleBy(df, "key", strataFractions, 0)
  perKeyCounts <- count(groupBy(sampled, "key"))
  result <- collect(orderBy(perKeyCounts, "key"))
  expect_identical(as.list(result[1, ]), list(key = "0", count = 2))
  expect_identical(as.list(result[2, ]), list(key = "1", count = 10))
})
test_that("SQL error message is returned from JVM", { test_that("SQL error message is returned from JVM", {
retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
expect_equal(grepl("Table Not Found: blah", retError), TRUE) expect_equal(grepl("Table Not Found: blah", retError), TRUE)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment