diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a82ded9c51facb77f9d144d9687da181c4b5801a..81b4e6b91d8a2e7913af54b758fc57c2448d2f49 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -822,21 +822,21 @@ setMethod("collect", # Get a column of complex type returns a list. # Get a cell from a column of complex type returns a list instead of a vector. col <- listCols[[colIndex]] - colName <- dtypes[[colIndex]][[1]] if (length(col) <= 0) { - df[[colName]] <- col + df[[colIndex]] <- col } else { colType <- dtypes[[colIndex]][[2]] # Note that "binary" columns behave like complex types. if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { vec <- do.call(c, col) stopifnot(class(vec) != "list") - df[[colName]] <- vec + df[[colIndex]] <- vec } else { - df[[colName]] <- col + df[[colIndex]] <- col } } } + names(df) <- names(x) df } }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 92ec82096c6df5b4fdc7f58a2ac3115ff7e35881..1e7cb5409970381fb0aa0722192b06054a149b6f 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -530,6 +530,11 @@ test_that("collect() returns a data.frame", { expect_equal(names(rdf)[1], "age") expect_equal(nrow(rdf), 0) expect_equal(ncol(rdf), 2) + + # collect() correctly handles multiple columns with same name + df <- createDataFrame(sqlContext, list(list(1, 2)), schema = c("name", "name")) + ldf <- collect(df) + expect_equal(names(ldf), c("name", "name")) }) test_that("limit() returns DataFrame with the correct number of rows", { @@ -1197,6 +1202,7 @@ test_that("join() and merge() on a DataFrame", { joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) expect_equal(count(joined), 12) + expect_equal(names(collect(joined)), c("age", "name", "name", "test")) joined2 <- join(df, df2, df$name == df2$name) expect_equal(names(joined2), c("age", "name", "name", "test"))