From 5011f264fb53705c528250bd055acbc2eca2baaa Mon Sep 17 00:00:00 2001
From: Sun Rui <rui.sun@intel.com>
Date: Thu, 3 Dec 2015 21:11:10 -0800
Subject: [PATCH] [SPARK-12104][SPARKR] collect() does not handle multiple
 columns with same name.

Author: Sun Rui <rui.sun@intel.com>

Closes #10118 from sun-rui/SPARK-12104.
---
 R/pkg/R/DataFrame.R              | 8 ++++----
 R/pkg/inst/tests/test_sparkSQL.R | 6 ++++++
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index a82ded9c51..81b4e6b91d 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -822,21 +822,21 @@ setMethod("collect",
                 # Get a column of complex type returns a list.
                 # Get a cell from a column of complex type returns a list instead of a vector.
                 col <- listCols[[colIndex]]
-                colName <- dtypes[[colIndex]][[1]]
                 if (length(col) <= 0) {
-                  df[[colName]] <- col
+                  df[[colIndex]] <- col
                 } else {
                   colType <- dtypes[[colIndex]][[2]]
                   # Note that "binary" columns behave like complex types.
                   if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") {
                     vec <- do.call(c, col)
                     stopifnot(class(vec) != "list")
-                    df[[colName]] <- vec
+                    df[[colIndex]] <- vec
                   } else {
-                    df[[colName]] <- col
+                    df[[colIndex]] <- col
                   }
                 }
               }
+              names(df) <- names(x)
               df
             }
           })
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 92ec82096c..1e7cb54099 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -530,6 +530,11 @@ test_that("collect() returns a data.frame", {
   expect_equal(names(rdf)[1], "age")
   expect_equal(nrow(rdf), 0)
   expect_equal(ncol(rdf), 2)
+
+  # collect() correctly handles multiple columns with same name
+  df <- createDataFrame(sqlContext, list(list(1, 2)), schema = c("name", "name"))
+  ldf <- collect(df)
+  expect_equal(names(ldf), c("name", "name"))
 })
 
 test_that("limit() returns DataFrame with the correct number of rows", {
@@ -1197,6 +1202,7 @@ test_that("join() and merge() on a DataFrame", {
   joined <- join(df, df2)
   expect_equal(names(joined), c("age", "name", "name", "test"))
   expect_equal(count(joined), 12)
+  expect_equal(names(collect(joined)), c("age", "name", "name", "test"))
 
   joined2 <- join(df, df2, df$name == df2$name)
   expect_equal(names(joined2), c("age", "name", "name", "test"))
-- 
GitLab