From 5a799fd8c3664da1fa9821ead6c0e25f561c6a8d Mon Sep 17 00:00:00 2001
From: zero323 <zero323@users.noreply.github.com>
Date: Sun, 14 May 2017 13:22:19 -0700
Subject: [PATCH] [SPARK-20726][SPARKR] wrapper for SQL broadcast

## What changes were proposed in this pull request?

- Adds an R wrapper for `o.a.s.sql.functions.broadcast`, exposed as a `broadcast` method on `SparkDataFrame` (equivalent to `hint(x, "broadcast")`); a short usage sketch follows this list.
- Renames the existing broadcast-variable helper `broadcast` in `R/pkg/R/context.R` to `broadcastRDD`, freeing the `broadcast` name for the new method.
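
A minimal usage sketch of the new wrapper, mirroring the roxygen example in this patch (assumes a local SparkR session; the data and column names are only illustrative):

```r
library(SparkR)
sparkR.session(master = "local[1]")

df <- createDataFrame(mtcars)
# Per-cylinder average mpg -- a small lookup table worth broadcasting.
avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg")

# Mark the small side so the planner chooses a broadcast join.
head(join(df, broadcast(avg_mpg), df$cyl == avg_mpg$cyl))
```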

## How was this patch tested?

Unit tests; checked with `check-cran.sh`.
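
For an interactive sanity check that mirrors the new plan assertion in `test_sparkSQL.R` (a sketch; `df1`/`df2` are illustrative and assume a local session):

```r
library(SparkR)
sparkR.session(master = "local[1]")

df1 <- createDataFrame(data.frame(id = 1:5, a = letters[1:5]))
df2 <- createDataFrame(data.frame(id = 1:5, b = LETTERS[1:5]))

# The printed physical plan should contain a BroadcastHashJoin node.
explain(join(df1, broadcast(df2), df1$id == df2$id))
```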

Author: zero323 <zero323@users.noreply.github.com>

Closes #17965 from zero323/SPARK-20726.
---
 R/pkg/NAMESPACE                            |  1 +
 R/pkg/R/DataFrame.R                        | 29 ++++++++++++++++++++++
 R/pkg/R/context.R                          |  4 +--
 R/pkg/R/generics.R                         |  4 +++
 R/pkg/inst/tests/testthat/test_broadcast.R |  2 +-
 R/pkg/inst/tests/testthat/test_sparkSQL.R  |  5 ++++
 R/pkg/inst/tests/testthat/test_utils.R     |  2 +-
 7 files changed, 43 insertions(+), 4 deletions(-)

diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index ba0fe7708b..5c074d3c0f 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -84,6 +84,7 @@ exportClasses("SparkDataFrame")
 exportMethods("arrange",
               "as.data.frame",
               "attach",
+              "broadcast",
               "cache",
               "checkpoint",
               "coalesce",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index b56dddcb9f..aab2fc17ae 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -3769,3 +3769,32 @@ setMethod("alias",
             sdf <- callJMethod(object@sdf, "alias", data)
             dataFrame(sdf)
           })
+
+#' broadcast
+#'
+#' Return a new SparkDataFrame marked as small enough for use in broadcast joins.
+#'
+#' Equivalent to \code{hint(x, "broadcast")}.
+#'
+#' @param x a SparkDataFrame.
+#' @return a SparkDataFrame.
+#'
+#' @aliases broadcast,SparkDataFrame-method
+#' @family SparkDataFrame functions
+#' @rdname broadcast
+#' @name broadcast
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(mtcars)
+#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg")
+#'
+#' head(join(df, broadcast(avg_mpg), df$cyl == avg_mpg$cyl))
+#' }
+#' @note broadcast since 2.3.0
+setMethod("broadcast",
+          signature(x = "SparkDataFrame"),
+          function(x) {
+            sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf)
+            dataFrame(sdf)
+          })
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index 50856e3d98..8349b57a30 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -258,7 +258,7 @@ includePackage <- function(sc, pkg) {
 #'
 #' # Large Matrix object that we want to broadcast
 #' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000))
-#' randomMatBr <- broadcast(sc, randomMat)
+#' randomMatBr <- broadcastRDD(sc, randomMat)
 #'
 #' # Use the broadcast variable inside the function
 #' useBroadcast <- function(x) {
@@ -266,7 +266,7 @@ includePackage <- function(sc, pkg) {
 #' }
 #' sumRDD <- lapply(rdd, useBroadcast)
 #'}
-broadcast <- function(sc, object) {
+broadcastRDD <- function(sc, object) {
   objName <- as.character(substitute(object))
   serializedObj <- serialize(object, connection = NULL)
 
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 3c84bf8a48..514ca99d45 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -799,6 +799,10 @@ setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.d
 #' @export
 setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") })
 
+#' @rdname broadcast
+#' @export
+setGeneric("broadcast", function(x) { standardGeneric("broadcast") })
+
 ###################### Column Methods ##########################
 
 #' @rdname columnfunctions
diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R
index 254f8f522a..2c96740df7 100644
--- a/R/pkg/inst/tests/testthat/test_broadcast.R
+++ b/R/pkg/inst/tests/testthat/test_broadcast.R
@@ -29,7 +29,7 @@ test_that("using broadcast variable", {
   skip_on_cran()
 
   randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100))
-  randomMatBr <- broadcast(sc, randomMat)
+  randomMatBr <- broadcastRDD(sc, randomMat)
 
   useBroadcast <- function(x) {
     sum(SparkR:::value(randomMatBr) * x)
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 0ff2e02e75..b633b78d5b 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2216,6 +2216,11 @@ test_that("join(), crossJoin() and merge() on a DataFrame", {
     explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id))
   )
   expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint)))
+
+  execution_plan_broadcast <- capture.output(
+    explain(join(df1, broadcast(df2), df1$id == df2$id))
+  )
+  expect_true(any(grepl("BroadcastHashJoin", execution_plan_broadcast)))
 })
 
 test_that("toJSON() on DataFrame", {
diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R
index 2fc6530d63..02691f0f64 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/inst/tests/testthat/test_utils.R
@@ -136,7 +136,7 @@ test_that("cleanClosure on R functions", {
 
   # Test for broadcast variables.
   a <- matrix(nrow = 10, ncol = 10, data = rnorm(100))
-  aBroadcast <- broadcast(sc, a)
+  aBroadcast <- broadcastRDD(sc, a)
   normMultiply <- function(x) { norm(aBroadcast$value) * x }
   newnormMultiply <- SparkR:::cleanClosure(normMultiply)
   env <- environment(newnormMultiply)
-- 
GitLab