From 9e785079b6ed4ea691c3c14c762a7f73fb6254bf Mon Sep 17 00:00:00 2001 From: Sun Rui <rui.sun@intel.com> Date: Thu, 28 Apr 2016 09:33:58 -0700 Subject: [PATCH] [SPARK-12235][SPARKR] Enhance mutate() to support replace existing columns. Make the behavior of mutate more consistent with that in dplyr, besides support for replacing existing columns. 1. Throw error message when there are duplicated column names in the DataFrame being mutated. 2. when there are duplicated column names in specified columns by arguments, the last column of the same name takes effect. Author: Sun Rui <rui.sun@intel.com> Closes #10220 from sun-rui/SPARK-12235. --- R/pkg/R/DataFrame.R | 60 +++++++++++++++++++---- R/pkg/inst/tests/testthat/test_sparkSQL.R | 18 +++++++ 2 files changed, 69 insertions(+), 9 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 48ac1b06f6..a741fdf709 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1431,11 +1431,11 @@ setMethod("withColumn", #' Mutate #' -#' Return a new SparkDataFrame with the specified columns added. +#' Return a new SparkDataFrame with the specified columns added or replaced. #' #' @param .data A SparkDataFrame #' @param col a named argument of the form name = col -#' @return A new SparkDataFrame with the new columns added. +#' @return A new SparkDataFrame with the new columns added or replaced. #' @family SparkDataFrame functions #' @rdname mutate #' @name mutate @@ -1450,23 +1450,65 @@ setMethod("withColumn", #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 #' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2) +#' +#' df <- createDataFrame(sqlContext, +#' list(list("Andy", 30L), list("Justin", 19L)), c("name", "age")) +#' # Replace the "age" column +#' df1 <- mutate(df, age = df$age + 1L) #' } setMethod("mutate", signature(.data = "SparkDataFrame"), function(.data, ...) { x <- .data cols <- list(...) - stopifnot(length(cols) > 0) - stopifnot(class(cols[[1]]) == "Column") + if (length(cols) <= 0) { + return(x) + } + + lapply(cols, function(col) { + stopifnot(class(col) == "Column") + }) + + # Check if there is any duplicated column name in the DataFrame + dfCols <- columns(x) + if (length(unique(dfCols)) != length(dfCols)) { + stop("Error: found duplicated column name in the DataFrame") + } + + # TODO: simplify the implementation of this method after SPARK-12225 is resolved. + + # For named arguments, use the names for arguments as the column names + # For unnamed arguments, use the argument symbols as the column names + args <- sapply(substitute(list(...))[-1], deparse) ns <- names(cols) if (!is.null(ns)) { - for (n in ns) { - if (n != "") { - cols[[n]] <- alias(cols[[n]], n) + lapply(seq_along(args), function(i) { + if (ns[[i]] != "") { + args[[i]] <<- ns[[i]] } - } + }) + } + ns <- args + + # The last column of the same name in the specific columns takes effect + deDupCols <- list() + for (i in 1:length(cols)) { + deDupCols[[ns[[i]]]] <- alias(cols[[i]], ns[[i]]) } - do.call(select, c(x, x$"*", cols)) + + # Construct the column list for projection + colList <- lapply(dfCols, function(col) { + if (!is.null(deDupCols[[col]])) { + # Replace existing column + tmpCol <- deDupCols[[col]] + deDupCols[[col]] <<- NULL + tmpCol + } else { + col(col) + } + }) + + do.call(select, c(x, colList, deDupCols)) }) #' @export diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 95d6cb8875..7058265ea3 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1581,6 +1581,24 @@ test_that("mutate(), transform(), rename() and names()", { expect_equal(columns(newDF)[3], "newAge") expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) + newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 33) + expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32) + + newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3, + age = df$age + 4, newAge = df$age + 5) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 35) + expect_equal(first(filter(newDF, df$name != "Michael"))$age, 34) + + newDF <- mutate(df, df$age + 3) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[[3]], "df$age + 3") + expect_equal(first(filter(newDF, df$name != "Michael"))[[3]], 33) + newDF2 <- rename(df, newerAge = df$age) expect_equal(length(columns(newDF2)), 2) expect_equal(columns(newDF2)[1], "newerAge") -- GitLab