From 5a799fd8c3664da1fa9821ead6c0e25f561c6a8d Mon Sep 17 00:00:00 2001 From: zero323 <zero323@users.noreply.github.com> Date: Sun, 14 May 2017 13:22:19 -0700 Subject: [PATCH] [SPARK-20726][SPARKR] wrapper for SQL broadcast ## What changes were proposed in this pull request? - Adds R wrapper for `o.a.s.sql.functions.broadcast`. - Renames `broadcast` to `broadcastRDD`. ## How was this patch tested? Unit tests, check `check-cran.sh`. Author: zero323 <zero323@users.noreply.github.com> Closes #17965 from zero323/SPARK-20726. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 29 ++++++++++++++++++++++ R/pkg/R/context.R | 4 +-- R/pkg/R/generics.R | 4 +++ R/pkg/inst/tests/testthat/test_broadcast.R | 2 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 5 ++++ R/pkg/inst/tests/testthat/test_utils.R | 2 +- 7 files changed, 43 insertions(+), 4 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ba0fe7708b..5c074d3c0f 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -84,6 +84,7 @@ exportClasses("SparkDataFrame") exportMethods("arrange", "as.data.frame", "attach", + "broadcast", "cache", "checkpoint", "coalesce", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index b56dddcb9f..aab2fc17ae 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3769,3 +3769,32 @@ setMethod("alias", sdf <- callJMethod(object@sdf, "alias", data) dataFrame(sdf) }) + +#' broadcast +#' +#' Return a new SparkDataFrame marked as small enough for use in broadcast joins. +#' +#' Equivalent to \code{hint(x, "broadcast")}. +#' +#' @param x a SparkDataFrame. +#' @return a SparkDataFrame. 
+#' +#' @aliases broadcast,SparkDataFrame-method +#' @family SparkDataFrame functions +#' @rdname broadcast +#' @name broadcast +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg") +#' +#' head(join(df, broadcast(avg_mpg), df$cyl == avg_mpg$cyl)) +#' } +#' @note broadcast since 2.3.0 +setMethod("broadcast", + signature(x = "SparkDataFrame"), + function(x) { + sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf) + dataFrame(sdf) + }) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 50856e3d98..8349b57a30 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -258,7 +258,7 @@ includePackage <- function(sc, pkg) { #' #' # Large Matrix object that we want to broadcast #' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) -#' randomMatBr <- broadcast(sc, randomMat) +#' randomMatBr <- broadcastRDD(sc, randomMat) #' #' # Use the broadcast variable inside the function #' useBroadcast <- function(x) { @@ -266,7 +266,7 @@ includePackage <- function(sc, pkg) { #' } #' sumRDD <- lapply(rdd, useBroadcast) #'} -broadcast <- function(sc, object) { +broadcastRDD <- function(sc, object) { objName <- as.character(substitute(object)) serializedObj <- serialize(object, connection = NULL) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 3c84bf8a48..514ca99d45 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -799,6 +799,10 @@ setGeneric("write.df", function(df, path = NULL, ...) 
{ standardGeneric("write.d #' @export setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") }) +#' @rdname broadcast +#' @export +setGeneric("broadcast", function(x) { standardGeneric("broadcast") }) + ###################### Column Methods ########################## #' @rdname columnfunctions diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index 254f8f522a..2c96740df7 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -29,7 +29,7 @@ test_that("using broadcast variable", { skip_on_cran() randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) - randomMatBr <- broadcast(sc, randomMat) + randomMatBr <- broadcastRDD(sc, randomMat) useBroadcast <- function(x) { sum(SparkR:::value(randomMatBr) * x) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 0ff2e02e75..b633b78d5b 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2216,6 +2216,11 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id)) ) expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint))) + + execution_plan_broadcast <- capture.output( + explain(join(df1, broadcast(df2), df1$id == df2$id)) + ) + expect_true(any(grepl("BroadcastHashJoin", execution_plan_broadcast))) }) test_that("toJSON() on DataFrame", { diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 2fc6530d63..02691f0f64 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -136,7 +136,7 @@ test_that("cleanClosure on R functions", { # Test for broadcast variables. 
a <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) - aBroadcast <- broadcast(sc, a) + aBroadcast <- broadcastRDD(sc, a) normMultiply <- function(x) { norm(aBroadcast$value) * x } newnormMultiply <- SparkR:::cleanClosure(normMultiply) env <- environment(newnormMultiply) -- GitLab