diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index a8cf53fd46c2ec519c39f0c52353677ab57d3825..8db4d5ca1ee532c5cbc65b3a22e2234a0aa7588a 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -62,6 +62,7 @@ exportMethods("arrange", "filter", "first", "freqItems", + "gapply", "group_by", "groupBy", "head", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 0ff350d44d4b3160d07f3b0a1c56e491243b2172..9a9b3f7ecae164e7548327b1dbeff0d583deb4c4 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1181,7 +1181,7 @@ dapplyInternal <- function(x, func, schema) { #' func should have only one parameter, to which a data.frame corresponds #' to each partition will be passed. #' The output of func should be a data.frame. -#' @param schema The schema of the resulting DataFrame after the function is applied. +#' @param schema The schema of the resulting SparkDataFrame after the function is applied. #' It must match the output of func. #' @family SparkDataFrame functions #' @rdname dapply @@ -1267,6 +1267,86 @@ setMethod("dapplyCollect", ldf }) +#' gapply +#' +#' Group the SparkDataFrame using the specified columns and apply the R function to each +#' group. +#' +#' @param x A SparkDataFrame +#' @param cols Grouping columns +#' @param func A function to be applied to each group of the SparkDataFrame, as determined +#' by the grouping columns. The function `func` takes two arguments: +#' a key (the values of the grouping columns) and a local R data.frame +#' containing the rows of that group. +#' The output of `func` is a local R data.frame. +#' @param schema The schema of the resulting SparkDataFrame after the function is applied. +#' The schema must match the output of `func`. Each output column must be +#' defined with its preferred column name and corresponding data type. +#' @family SparkDataFrame functions +#' @rdname gapply +#' @name gapply +#' @export +#' @examples +#' +#' \dontrun{ +#' Computes the arithmetic mean of the second column by grouping +#' on the first and third columns. Outputs the grouping values and the average. +#' +#' df <- createDataFrame ( +#' list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), +#' c("a", "b", "c", "d")) +#' +#' Here our output contains three columns: the key, which is a combination of two +#' columns with data types integer and string, and the mean, which is a double. +#' schema <- structType(structField("a", "integer"), structField("c", "string"), +#' structField("avg", "double")) +#' df1 <- gapply( +#' df, +#' list("a", "c"), +#' function(key, x) { +#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) +#' }, +#' schema) +#' collect(df1) +#' +#' Result +#' ------ +#' a c avg +#' 3 3 3.0 +#' 1 1 1.5 +#' +#' Fits linear models on the iris dataset by grouping on the 'Species' column and +#' using 'Sepal_Length' as a target variable, 'Sepal_Width', 'Petal_Length' +#' and 'Petal_Width' as training features.
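+#' Note: createDataFrame replaces the dots in the iris column names with underscores +#' (e.g. Sepal.Length becomes Sepal_Length), which is why the schema and the formula +#' below use the underscore form.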
+#' +#' df <- createDataFrame (iris) +#' schema <- structType(structField("(Intercept)", "double"), +#' structField("Sepal_Width", "double"), structField("Petal_Length", "double"), +#' structField("Petal_Width", "double")) +#' df1 <- gapply( +#' df, +#' list(df$"Species"), +#' function(key, x) { +#' m <- suppressWarnings(lm(Sepal_Length ~ +#' Sepal_Width + Petal_Length + Petal_Width, x)) +#' data.frame(t(coef(m))) +#' }, schema) +#' collect(df1) +#' +#' Result +#' --------- +#' Model (Intercept) Sepal_Width Petal_Length Petal_Width +#' 1 0.699883 0.3303370 0.9455356 -0.1697527 +#' 2 1.895540 0.3868576 0.9083370 -0.6792238 +#' 3 2.351890 0.6548350 0.2375602 0.2521257 +#' +#'} +setMethod("gapply", + signature(x = "SparkDataFrame"), + function(x, cols, func, schema) { + grouped <- do.call("groupBy", c(x, cols)) + gapply(grouped, func, schema) + }) + ############################## RDD Map Functions ################################## # All of the following functions mirror the existing RDD map functions, # # but allow for use with DataFrames by first converting to an RRDD before calling # diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index ce071b1a848bbd0557338dcd61d0ff756aa13cde..0e99b171cabebbc7949dcb1e91ac39263e717db2 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -197,6 +197,36 @@ readMultipleObjects <- function(inputCon) { data # this is a list of named lists now } +readMultipleObjectsWithKeys <- function(inputCon) { + # readMultipleObjectsWithKeys will read multiple continuous objects from + # a DataOutputStream. There is no preceding field telling the count + # of the objects, so the number of objects varies; we try to read + # all objects in a loop until the end of the stream. This function + # is for use by gapply. Each group of rows is followed by the grouping + # key for this group, which is then followed by the next group. + keys <- list() + data <- list() + subData <- list() + while (TRUE) { + # If reaching the end of the stream, type returned should be "". + type <- readType(inputCon) + if (type == "") { + break + } else if (type == "r") { + type <- readType(inputCon) + # A grouping boundary detected + key <- readTypedObject(inputCon, type) + index <- length(data) + 1L + data[[index]] <- subData + keys[[index]] <- key + subData <- list() + } else { + subData[[length(subData) + 1L]] <- readTypedObject(inputCon, type) + } + } + list(keys = keys, data = data) # this is a list of keys and corresponding data +} + readRowList <- function(obj) { # readRowList is meant for use inside an lapply. As a result, it is # necessary to open a standalone connection for the row and consume diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 50fc204f998a50db76bd8bcdbb1b2cfceae38f0c..40a96d8991a5a740e78c68764176fa25ea3ecec6 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -454,6 +454,10 @@ setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") }) #' @export setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") }) +#' @rdname gapply +#' @export +setGeneric("gapply", function(x, ...) { standardGeneric("gapply") }) + #' @rdname summary #' @export setGeneric("describe", function(x, col, ...)
{ standardGeneric("describe") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 08f4a490c883ef413dd4a29d0eb6a7b71fdbc5e5..b7047769175a3a13d97b5bd87c4a78f436b0571d 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -142,3 +142,65 @@ createMethods <- function() { } createMethods() + +#' gapply +#' +#' Applies a R function to each group in the input GroupedData +#' +#' @param x a GroupedData +#' @param func A function to be applied to each group partition specified by GroupedData. +#' The function `func` takes as argument a key - grouping columns and +#' a data frame - a local R data.frame. +#' The output of `func` is a local R data.frame. +#' @param schema The schema of the resulting SparkDataFrame after the function is applied. +#' The schema must match to output of `func`. It has to be defined for each +#' output column with preferred output column name and corresponding data type. +#' @return a SparkDataFrame +#' @rdname gapply +#' @name gapply +#' @examples +#' \dontrun{ +#' Computes the arithmetic mean of the second column by grouping +#' on the first and third columns. Output the grouping values and the average. +#' +#' df <- createDataFrame ( +#' list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), +#' c("a", "b", "c", "d")) +#' +#' Here our output contains three columns, the key which is a combination of two +#' columns with data types integer and string and the mean which is a double. +#' schema <- structType(structField("a", "integer"), structField("c", "string"), +#' structField("avg", "double")) +#' df1 <- gapply( +#' df, +#' list("a", "c"), +#' function(key, x) { +#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) +#' }, +#' schema) +#' collect(df1) +#' +#' Result +#' ------ +#' a c avg +#' 3 3 3.0 +#' 1 1 1.5 +#' } +setMethod("gapply", + signature(x = "GroupedData"), + function(x, func, schema) { + try(if (is.null(schema)) stop("schema cannot be NULL")) + packageNamesArr <- serialize(.sparkREnv[[".packages"]], + connection = NULL) + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) + sdf <- callJStatic( + "org.apache.spark.sql.api.r.SQLUtils", + "gapply", + x@sgd, + serialize(cleanClosure(func), connection = NULL), + packageNamesArr, + broadcastArr, + schema$jobj) + dataFrame(sdf) + }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index d1ca3b726fe0b2a93494cf9d2da14e6502481fe9..c11930ada63cea3b66504600c3bf5f9da26cbb83 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2146,6 +2146,71 @@ test_that("repartition by columns on DataFrame", { expect_equal(nrow(df1), 2) }) +test_that("gapply() on a DataFrame", { + df <- createDataFrame ( + list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), + c("a", "b", "c", "d")) + expected <- collect(df) + df1 <- gapply(df, list("a"), function(key, x) { x }, schema(df)) + actual <- collect(df1) + expect_identical(actual, expected) + + # Computes the sum of second column by grouping on the first and third columns + # and checks if the sum is larger than 2 + schema <- structType(structField("a", "integer"), structField("e", "boolean")) + df2 <- gapply( + df, + list(df$"a", df$"c"), + function(key, x) { + y <- data.frame(key[1], sum(x$b) > 2) + }, + schema) + actual <- collect(df2)$e + expected <- c(TRUE, TRUE) + expect_identical(actual, expected) + + # Computes the arithmetic mean of the second column by grouping + # on the first and 
third columns. Outputs the grouping values and the average. + schema <- structType(structField("a", "integer"), structField("c", "string"), + structField("avg", "double")) + df3 <- gapply( + df, + list("a", "c"), + function(key, x) { + y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) + }, + schema) + actual <- collect(df3) + actual <- actual[order(actual$a), ] + rownames(actual) <- NULL + expected <- collect(select(df, "a", "b", "c")) + expected <- data.frame(aggregate(expected$b, by = list(expected$a, expected$c), FUN = mean)) + colnames(expected) <- c("a", "c", "avg") + expected <- expected[order(expected$a), ] + rownames(expected) <- NULL + expect_identical(actual, expected) + + irisDF <- suppressWarnings(createDataFrame (iris)) + schema <- structType(structField("Sepal_Length", "double"), structField("Avg", "double")) + # Groups by `Sepal_Length` and computes the average for `Sepal_Width` + df4 <- gapply( + cols = list("Sepal_Length"), + irisDF, + function(key, x) { + y <- data.frame(key, mean(x$Sepal_Width), stringsAsFactors = FALSE) + }, + schema) + actual <- collect(df4) + actual <- actual[order(actual$Sepal_Length), ] + rownames(actual) <- NULL + agg_local_df <- data.frame(aggregate(iris$Sepal.Width, by = list(iris$Sepal.Length), FUN = mean), + stringsAsFactors = FALSE) + colnames(agg_local_df) <- c("Sepal_Length", "Avg") + expected <- agg_local_df[order(agg_local_df$Sepal_Length), ] + rownames(expected) <- NULL + expect_identical(actual, expected) +}) + test_that("Window functions on a DataFrame", { setHiveContext(sc) df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")), diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 40cda0c5ef9c146adcaab4743cb9b619a76d58d3..debf0180183a40e4e2b2ebd55b2a553af13e18ef 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -27,6 +27,54 @@ elapsedSecs <- function() { proc.time()[3] } +compute <- function(mode, partition, serializer, deserializer, key, + colNames, computeFunc, inputData) { + if (mode > 0) { + if (deserializer == "row") { + # Transform the list of rows into a data.frame + # Note that the optional argument stringsAsFactors for rbind is + # available since R 3.2.4. So we set the global option here.
+ oldOpt <- getOption("stringsAsFactors") + options(stringsAsFactors = FALSE) + inputData <- do.call(rbind.data.frame, inputData) + options(stringsAsFactors = oldOpt) + + names(inputData) <- colNames + } else { + # Check to see if inputData is a valid data.frame + stopifnot(deserializer == "byte") + stopifnot(class(inputData) == "data.frame") + } + + if (mode == 2) { + output <- computeFunc(key, inputData) + } else { + output <- computeFunc(inputData) + } + if (serializer == "row") { + # Transform the result data.frame back to a list of rows + output <- split(output, seq(nrow(output))) + } else { + # Serialize the output to a byte array + stopifnot(serializer == "byte") + } + } else { + output <- computeFunc(partition, inputData) + } + return(output) +} + +outputResult <- function(serializer, output, outputCon) { + if (serializer == "byte") { + SparkR:::writeRawSerialize(outputCon, output) + } else if (serializer == "row") { + SparkR:::writeRowSerialize(outputCon, output) + } else { + # write lines one-by-one with flag + lapply(output, function(line) SparkR:::writeString(outputCon, line)) + } +} + # Constants specialLengths <- list(END_OF_STERAM = 0L, TIMING_DATA = -1L) @@ -79,75 +127,71 @@ if (numBroadcastVars > 0) { # Timing broadcast broadcastElap <- elapsedSecs() +# Initial input timing +inputElap <- broadcastElap # If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int # as number of partitions to create. numPartitions <- SparkR:::readInt(inputCon) -isDataFrame <- as.logical(SparkR:::readInt(inputCon)) +# 0 - RDD mode, 1 - dapply mode, 2 - gapply mode +mode <- SparkR:::readInt(inputCon) -# If isDataFrame, then read column names -if (isDataFrame) { +if (mode > 0) { colNames <- SparkR:::readObject(inputCon) } isEmpty <- SparkR:::readInt(inputCon) +computeInputElapsDiff <- 0 +outputComputeElapsDiff <- 0 if (isEmpty != 0) { - if (numPartitions == -1) { if (deserializer == "byte") { # Now read as many characters as described in funcLen data <- SparkR:::readDeserialize(inputCon) } else if (deserializer == "string") { data <- as.list(readLines(inputCon)) + } else if (deserializer == "row" && mode == 2) { + dataWithKeys <- SparkR:::readMultipleObjectsWithKeys(inputCon) + keys <- dataWithKeys$keys + data <- dataWithKeys$data } else if (deserializer == "row") { data <- SparkR:::readMultipleObjects(inputCon) } + # Timing reading input data for execution inputElap <- elapsedSecs() - - if (isDataFrame) { - if (deserializer == "row") { - # Transform the list of rows into a data.frame - # Note that the optional argument stringsAsFactors for rbind is - # available since R 3.2.4. So we set the global option here.
- oldOpt <- getOption("stringsAsFactors") - options(stringsAsFactors = FALSE) - data <- do.call(rbind.data.frame, data) - options(stringsAsFactors = oldOpt) - - names(data) <- colNames - } else { - # Check to see if data is a valid data.frame - stopifnot(deserializer == "byte") - stopifnot(class(data) == "data.frame") - } - output <- computeFunc(data) - if (serializer == "row") { - # Transform the result data.frame back to a list of rows - output <- split(output, seq(nrow(output))) - } else { - # Serialize the ouput to a byte array - stopifnot(serializer == "byte") + if (mode > 0) { + if (mode == 1) { + output <- compute(mode, partition, serializer, deserializer, NULL, + colNames, computeFunc, data) + } else { + # gapply mode + for (i in 1:length(data)) { + # Timing reading input data for execution + inputElap <- elapsedSecs() + output <- compute(mode, partition, serializer, deserializer, keys[[i]], + colNames, computeFunc, data[[i]]) + computeElap <- elapsedSecs() + outputResult(serializer, output, outputCon) + outputElap <- elapsedSecs() + computeInputElapsDiff <- computeInputElapsDiff + (computeElap - inputElap) + outputComputeElapsDiff <- outputComputeElapsDiff + (outputElap - computeElap) + } } } else { - output <- computeFunc(partition, data) + output <- compute(mode, partition, serializer, deserializer, NULL, + colNames, computeFunc, data) } - - # Timing computing - computeElap <- elapsedSecs() - - if (serializer == "byte") { - SparkR:::writeRawSerialize(outputCon, output) - } else if (serializer == "row") { - SparkR:::writeRowSerialize(outputCon, output) - } else { - # write lines one-by-one with flag - lapply(output, function(line) SparkR:::writeString(outputCon, line)) + if (mode != 2) { + # Not a gapply mode + computeElap <- elapsedSecs() + outputResult(serializer, output, outputCon) + outputElap <- elapsedSecs() + computeInputElapsDiff <- computeElap - inputElap + outputComputeElapsDiff <- outputElap - computeElap } - # Timing output - outputElap <- elapsedSecs() } else { if (deserializer == "byte") { # Now read as many characters as described in funcLen @@ -189,11 +233,9 @@ if (isEmpty != 0) { } # Timing output outputElap <- elapsedSecs() + computeInputElapsDiff <- computeElap - inputElap + outputComputeElapsDiff <- outputElap - computeElap } -} else { - inputElap <- broadcastElap - computeElap <- broadcastElap - outputElap <- broadcastElap } # Report timing @@ -202,8 +244,8 @@ SparkR:::writeDouble(outputCon, bootTime) SparkR:::writeDouble(outputCon, initElap - bootElap) # init SparkR:::writeDouble(outputCon, broadcastElap - initElap) # broadcast SparkR:::writeDouble(outputCon, inputElap - broadcastElap) # input -SparkR:::writeDouble(outputCon, computeElap - inputElap) # compute -SparkR:::writeDouble(outputCon, outputElap - computeElap) # output +SparkR:::writeDouble(outputCon, computeInputElapsDiff) # compute +SparkR:::writeDouble(outputCon, outputComputeElapsDiff) # output # End of output SparkR:::writeInt(outputCon, specialLengths$END_OF_STERAM) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 24ad689f8321c4913aa9d6aeaaac7426ed9ccaf1..496fdf851f7db1b70e7a179c16c2ee93e9cdf1a2 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -40,7 +40,8 @@ private[spark] class RRunner[U]( broadcastVars: Array[Broadcast[Object]], numPartitions: Int = -1, isDataFrame: Boolean = false, - colNames: Array[String] = null) + 
colNames: Array[String] = null, + mode: Int = RRunnerModes.RDD) extends Logging { private var bootTime: Double = _ private var dataStream: DataInputStream = _ @@ -148,8 +149,7 @@ private[spark] class RRunner[U]( } dataOut.writeInt(numPartitions) - - dataOut.writeInt(if (isDataFrame) 1 else 0) + dataOut.writeInt(mode) if (isDataFrame) { SerDe.writeObject(dataOut, colNames) @@ -180,6 +180,13 @@ private[spark] class RRunner[U]( for (elem <- iter) { elem match { + case (key, innerIter: Iterator[_]) => + for (innerElem <- innerIter) { + writeElem(innerElem) + } + // Writes key which can be used as a boundary in group-aggregate + dataOut.writeByte('r') + writeElem(key) case (key, value) => writeElem(key) writeElem(value) @@ -187,6 +194,7 @@ private[spark] class RRunner[U]( writeElem(elem) } } + stream.flush() } catch { // TODO: We should propagate this error to the task thread @@ -268,6 +276,12 @@ private object SpecialLengths { val TIMING_DATA = -1 } +private[spark] object RRunnerModes { + val RDD = 0 + val DATAFRAME_DAPPLY = 1 + val DATAFRAME_GAPPLY = 2 +} + private[r] class BufferedStreamThread( in: InputStream, name: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 78e8822b6405a2e7925285277e9903fc504e2594..7beeeb4f04bf0debd512f9c72c80ced541069ee9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -246,6 +246,55 @@ case class MapGroups( outputObjAttr: Attribute, child: LogicalPlan) extends UnaryNode with ObjectProducer +/** Factory for constructing new `FlatMapGroupsInR` nodes. */ +object FlatMapGroupsInR { + def apply( + func: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + schema: StructType, + keyDeserializer: Expression, + valueDeserializer: Expression, + inputSchema: StructType, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + child: LogicalPlan): LogicalPlan = { + val mapped = FlatMapGroupsInR( + func, + packageNames, + broadcastVars, + inputSchema, + schema, + UnresolvedDeserializer(keyDeserializer, groupingAttributes), + UnresolvedDeserializer(valueDeserializer, dataAttributes), + groupingAttributes, + dataAttributes, + CatalystSerde.generateObjAttr(RowEncoder(schema)), + child) + CatalystSerde.serialize(mapped)(RowEncoder(schema)) + } +} + +case class FlatMapGroupsInR( + func: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + inputSchema: StructType, + outputSchema: StructType, + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectProducer{ + + override lazy val schema = outputSchema + + override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema, + keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, + child) +} + /** Factory for constructing new `CoGroup` nodes. 
*/ object CoGroup { def apply[K : Encoder, L : Encoder, R : Encoder, OUT : Encoder]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 49b6eab8db5b0d2080ad702d8796004b7615b6af..1aa5767038d539e60045ea95b8c8d07f557029cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -20,14 +20,18 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.language.implicitConversions +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.api.r.SQLUtils._ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot} import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.NumericType +import org.apache.spark.sql.types.StructType /** * A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]]. @@ -381,6 +385,48 @@ class RelationalGroupedDataset protected[sql]( def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { pivot(pivotColumn, values.asScala) } + + /** + * Applies the given serialized R function `func` to each group of data. For each unique group, + * the function will be passed the group key and an iterator that contains all of the elements in + * the group. The function can return an iterator containing elements of an arbitrary type which + * will be returned as a new [[DataFrame]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. 
+ * + * @since 2.0.0 + */ + private[sql] def flatMapGroupsInR( + f: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + outputSchema: StructType): DataFrame = { + val groupingNamedExpressions = groupingExprs.map(alias) + val groupingCols = groupingNamedExpressions.map(Column(_)) + val groupingDataFrame = df.select(groupingCols : _*) + val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) + Dataset.ofRows( + df.sparkSession, + FlatMapGroupsInR( + f, + packageNames, + broadcastVars, + outputSchema, + groupingDataFrame.exprEnc.deserializer, + df.exprEnc.deserializer, + df.exprEnc.schema, + groupingAttributes, + df.logicalPlan.output, + df.logicalPlan)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 486a440b6f9a8cc42d0d6dbc24e55a39987521da..fe426fa3c7e8a7b7ce8aa2b5b688f777b9b12379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -26,7 +26,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} +import org.apache.spark.sql.{DataFrame, RelationalGroupedDataset, Row, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.types._ @@ -146,16 +146,26 @@ private[sql] object SQLUtils { packageNames: Array[Byte], broadcastVars: Array[Object], schema: StructType): DataFrame = { - val bv = broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]]) - val realSchema = - if (schema == null) { - SERIALIZED_R_DATA_SCHEMA - } else { - schema - } + val bv = broadcastVars.map(_.asInstanceOf[Broadcast[Object]]) + val realSchema = if (schema == null) SERIALIZED_R_DATA_SCHEMA else schema df.mapPartitionsInR(func, packageNames, bv, realSchema) } + /** + * The helper function for gapply() on R side. 
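+ * If `schema` is null, the serialized R data schema is used instead, mirroring the + * dapply() helper above.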
+ */ + def gapply( + gd: RelationalGroupedDataset, + func: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Object], + schema: StructType): DataFrame = { + val bv = broadcastVars.map(_.asInstanceOf[Broadcast[Object]]) + val realSchema = if (schema == null) SERIALIZED_R_DATA_SCHEMA else schema + gd.flatMapGroupsInR(func, packageNames, bv, realSchema) + } + + def dfToCols(df: DataFrame): Array[Array[Any]] = { val localDF: Array[Row] = df.collect() val numCols = df.columns.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 60466e28307f76b6943f8101bed7cb7740b38d3f..8e2f2ed4f86b9dc957d11f388e8f6d7d2dd58ba7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -337,6 +337,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) => execution.MapPartitionsExec( execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil + case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) => + execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping, + data, objAttr, planLater(child)) :: Nil case logical.MapElements(f, objAttr, child) => execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil case logical.AppendColumns(f, in, out, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 5fced940b38d17e4d0ef2a921a446c8c04308209..c7e267152b5cdbaf5b95f48de39a126cb391e224 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -20,13 +20,17 @@ package org.apache.spark.sql.execution import scala.language.existentials import org.apache.spark.api.java.function.MapFunction +import org.apache.spark.api.r._ +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.sql.api.r.SQLUtils._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.types.{DataType, ObjectType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{DataType, ObjectType, StructType} /** @@ -324,6 +328,72 @@ case class MapGroupsExec( } } +/** + * Groups the input rows together and calls the R function with each group and an iterator + * containing all elements in the group. + * The result of this function is flattened before being output. 
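+ * Child rows are clustered by the grouping attributes and sorted within each partition, so + * that each group, together with its serialized key, can be handed to the R worker through + * RRunner in gapply mode and the returned bytes deserialized into rows of the output schema.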
+ */ +case class FlatMapGroupsInRExec( + func: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + inputSchema: StructType, + outputSchema: StructType, + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + child: SparkPlan) extends UnaryExecNode with ObjectProducerExec { + + override def output: Seq[Attribute] = outputObjAttr :: Nil + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingAttributes) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val isSerializedRData = + if (outputSchema == SERIALIZED_R_DATA_SCHEMA) true else false + val serializerForR = if (!isSerializedRData) { + SerializationFormats.ROW + } else { + SerializationFormats.BYTE + } + + child.execute().mapPartitionsInternal { iter => + val grouped = GroupedIterator(iter, groupingAttributes, child.output) + val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val runner = new RRunner[Array[Byte]]( + func, SerializationFormats.ROW, serializerForR, packageNames, broadcastVars, + isDataFrame = true, colNames = inputSchema.fieldNames, + mode = RRunnerModes.DATAFRAME_GAPPLY) + + val groupedRBytes = grouped.map { case (key, rowIter) => + val deserializedIter = rowIter.map(getValue) + val newIter = + deserializedIter.asInstanceOf[Iterator[Row]].map { row => rowToRBytes(row) } + val newKey = rowToRBytes(getKey(key).asInstanceOf[Row]) + (newKey, newIter) + } + + val outputIter = runner.compute(groupedRBytes, -1) + if (!isSerializedRData) { + val result = outputIter.map { bytes => bytesToRow(bytes, outputSchema) } + result.map(outputObject) + } else { + val result = outputIter.map { bytes => Row.fromSeq(Seq(bytes)) } + result.map(outputObject) + } + } + } +} + /** * Co-groups the data from left and right children, and calls the function with each group and 2 * iterators containing all elements in the group from left and right side. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala index 6c76328c74830fdd8affb714568b6f5eb91ad175..70539da348b0e0113cbc86ee752be433773611e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.r -import org.apache.spark.api.r.RRunner -import org.apache.spark.api.r.SerializationFormats +import org.apache.spark.api.r._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.api.r.SQLUtils._ import org.apache.spark.sql.Row @@ -55,7 +54,7 @@ private[sql] case class MapPartitionsRWrapper( val runner = new RRunner[Array[Byte]]( func, deserializer, serializer, packageNames, broadcastVars, - isDataFrame = true, colNames = colNames) + isDataFrame = true, colNames = colNames, mode = RRunnerModes.DATAFRAME_DAPPLY) // Partition index is ignored. 
Dataset has no support for mapPartitionsWithIndex. val outputIter = runner.compute(newIter, -1)