diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ba0fe7708bcc32999d2c7c5d297a9c79f4fe1f1a..5c074d3c0fd40cff7a623bfedc28af19525a45e6 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -84,6 +84,7 @@ exportClasses("SparkDataFrame") exportMethods("arrange", "as.data.frame", "attach", + "broadcast", "cache", "checkpoint", "coalesce", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index b56dddcb9f2ef45493407c19b5d90ded09ec2fed..aab2fc17aedaffae5b37fdb4c43e0ab269188689 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3769,3 +3769,32 @@ setMethod("alias", sdf <- callJMethod(object@sdf, "alias", data) dataFrame(sdf) }) + +#' broadcast +#' +#' Return a new SparkDataFrame marked as small enough for use in broadcast joins. +#' +#' Equivalent to \code{hint(x, "broadcast")}. +#' +#' @param x a SparkDataFrame. +#' @return a SparkDataFrame. +#' +#' @aliases broadcast,SparkDataFrame-method +#' @family SparkDataFrame functions +#' @rdname broadcast +#' @name broadcast +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg") +#' +#' head(join(df, broadcast(avg_mpg), df$cyl == avg_mpg$cyl)) +#' } +#' @note broadcast since 2.3.0 +setMethod("broadcast", + signature(x = "SparkDataFrame"), + function(x) { + sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf) + dataFrame(sdf) + }) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 50856e3d9856c8e858029b51329f885495f1e948..8349b57a30a938dc9b4cdd9d98b1054aa7ad8a59 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -258,7 +258,7 @@ includePackage <- function(sc, pkg) { #' #' # Large Matrix object that we want to broadcast #' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) -#' randomMatBr <- broadcast(sc, randomMat) +#' randomMatBr <- broadcastRDD(sc, randomMat) #' #' # Use the broadcast variable inside the function #' useBroadcast <- function(x) { @@ -266,7 +266,7 @@ includePackage <- function(sc, pkg) { #' } #' sumRDD <- lapply(rdd, useBroadcast) #'} -broadcast <- function(sc, object) { +broadcastRDD <- function(sc, object) { objName <- as.character(substitute(object)) serializedObj <- serialize(object, connection = NULL) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 3c84bf8a4803e5d7ba95b63fcf6a593cd7fc17ff..514ca99d45cd343fcb5474214bf3d6006f095bae 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -799,6 +799,10 @@ setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.d #' @export setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") }) +#' @rdname broadcast +#' @export +setGeneric("broadcast", function(x) { standardGeneric("broadcast") }) + ###################### Column Methods ########################## #' @rdname columnfunctions diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index 254f8f522a70823b73ecf8b8619c879f4fc8a8ca..2c96740df77bbee56a1aac719b14097fbded72a3 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -29,7 +29,7 @@ test_that("using broadcast variable", { skip_on_cran() randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) - randomMatBr <- broadcast(sc, randomMat) + randomMatBr <- broadcastRDD(sc, randomMat) useBroadcast <- function(x) { sum(SparkR:::value(randomMatBr) * x) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 0ff2e02e75a98b42f099de300e087175f0282a02..b633b78d5bb4dbec4ff9c4541f1c845228a49cfa 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2216,6 +2216,11 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id)) ) expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint))) + + execution_plan_broadcast <- capture.output( + explain(join(df1, broadcast(df2), df1$id == df2$id)) + ) + expect_true(any(grepl("BroadcastHashJoin", execution_plan_broadcast))) }) test_that("toJSON() on DataFrame", { diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 2fc6530d63e545ebf3b290eee1bb7c9de6f65c16..02691f0f64314421f44d4a9a574448b54b034e21 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -136,7 +136,7 @@ test_that("cleanClosure on R functions", { # Test for broadcast variables. a <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) - aBroadcast <- broadcast(sc, a) + aBroadcast <- broadcastRDD(sc, a) normMultiply <- function(x) { norm(aBroadcast$value) * x } newnormMultiply <- SparkR:::cleanClosure(normMultiply) env <- environment(newnormMultiply)