Skip to content
Snippets Groups Projects
Commit 217db56b authored by Dongjoon Hyun's avatar Dongjoon Hyun Committed by Shivaram Venkataraman
Browse files

[SPARK-15294][R] Add `pivot` to SparkR

## What changes were proposed in this pull request?

This PR adds `pivot` function to SparkR for API parity. Since this PR is based on https://github.com/apache/spark/pull/13295 , mhnatiuk should be credited for the work he did.

## How was this patch tested?

Pass the Jenkins tests (including new testcase.)

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13786 from dongjoon-hyun/SPARK-15294.
parent a46553cb
No related branches found
No related tags found
No related merge requests found
......@@ -294,6 +294,7 @@ exportMethods("%in%",
exportClasses("GroupedData")
exportMethods("agg")
exportMethods("pivot")
export("as.DataFrame",
"cacheTable",
......
......@@ -160,6 +160,10 @@ setGeneric("persist", function(x, newLevel) { standardGeneric("persist") })
# @export
setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")})
# @rdname pivot
# @export
setGeneric("pivot", function(x, colname, values = list()) { standardGeneric("pivot") })
# @rdname reduce
# @export
setGeneric("reduce", function(x, func) { standardGeneric("reduce") })
......
......@@ -134,6 +134,49 @@ methods <- c("avg", "max", "mean", "min", "sum")
# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop",
# "variance", "var_samp", "var_pop"
#' Pivot a column of the GroupedData and perform the specified aggregation.
#'
#' Pivot a column of the GroupedData and perform the specified aggregation.
#' There are two versions of pivot function: one that requires the caller to specify the list
#' of distinct values to pivot on, and one that does not. The latter is more concise but less
#' efficient, because Spark needs to first compute the list of distinct values internally.
#'
#' @param x a GroupedData object
#' @param colname A column name
#' @param values A value or a list/vector of distinct values for the output columns.
#' @return GroupedData object
#' @rdname pivot
#' @name pivot
#' @export
#' @examples
#' \dontrun{
#' df <- createDataFrame(data.frame(
#' earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000),
#' course = c("R", "Python", "R", "Python", "R", "Python", "R", "Python"),
#' period = c("1H", "1H", "2H", "2H", "1H", "1H", "2H", "2H"),
#' year = c(2015, 2015, 2015, 2015, 2016, 2016, 2016, 2016)
#' ))
#' group_sum <- sum(pivot(groupBy(df, "year"), "course"), "earnings")
#' group_min <- min(pivot(groupBy(df, "year"), "course", "R"), "earnings")
#' group_max <- max(pivot(groupBy(df, "year"), "course", c("Python", "R")), "earnings")
#' group_mean <- mean(pivot(groupBy(df, "year"), "course", list("Python", "R")), "earnings")
#' }
#' @note pivot since 2.0.0
setMethod("pivot",
signature(x = "GroupedData", colname = "character"),
function(x, colname, values = list()){
stopifnot(length(colname) == 1)
if (length(values) == 0) {
result <- callJMethod(x@sgd, "pivot", colname)
} else {
if (length(values) > length(unique(values))) {
stop("Values are not unique")
}
result <- callJMethod(x@sgd, "pivot", colname, as.list(values))
}
groupedData(result)
})
createMethod <- function(name) {
setMethod(name,
signature(x = "GroupedData"),
......
......@@ -1398,6 +1398,31 @@ test_that("group by, agg functions", {
unlink(jsonPath3)
})
test_that("pivot GroupedData column", {
df <- createDataFrame(data.frame(
earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000),
course = c("R", "Python", "R", "Python", "R", "Python", "R", "Python"),
year = c(2013, 2013, 2014, 2014, 2015, 2015, 2016, 2016)
))
sum1 <- collect(sum(pivot(groupBy(df, "year"), "course"), "earnings"))
sum2 <- collect(sum(pivot(groupBy(df, "year"), "course", c("Python", "R")), "earnings"))
sum3 <- collect(sum(pivot(groupBy(df, "year"), "course", list("Python", "R")), "earnings"))
sum4 <- collect(sum(pivot(groupBy(df, "year"), "course", "R"), "earnings"))
correct_answer <- data.frame(
year = c(2013, 2014, 2015, 2016),
Python = c(10000, 15000, 20000, 22000),
R = c(10000, 11000, 12000, 21000)
)
expect_equal(sum1, correct_answer)
expect_equal(sum2, correct_answer)
expect_equal(sum3, correct_answer)
expect_equal(sum4, correct_answer[, c("year", "R")])
expect_error(collect(sum(pivot(groupBy(df, "year"), "course", c("R", "R")), "earnings")))
expect_error(collect(sum(pivot(groupBy(df, "year"), "course", list("R", "R")), "earnings")))
})
test_that("arrange() and orderBy() on a DataFrame", {
df <- read.json(jsonPath)
sorted <- arrange(df, df$age)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment