diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fefe25b1480a335eda3d4e130756723ee84dd42e..5bca4105fccd5db51924d32f1aa30c78644de683 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -415,7 +415,7 @@ setMethod("coltypes", type <- PRIMITIVE_TYPES[[specialtype]] } } - type + type[[1]] }) # Find which types don't have mapping to R @@ -1136,6 +1136,7 @@ setMethod("collect", if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { vec <- do.call(c, col) stopifnot(class(vec) != "list") + class(vec) <- PRIMITIVE_TYPES[[colType]] df[[colIndex]] <- vec } else { df[[colIndex]] <- col diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index abca703617c7bcb0b77f2eded74220d2e93f00ef..ade0f05c02542381ee7e9f291de0e888f1792b6c 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -29,7 +29,7 @@ PRIMITIVE_TYPES <- as.environment(list( "string" = "character", "binary" = "raw", "boolean" = "logical", - "timestamp" = "POSIXct", + "timestamp" = c("POSIXct", "POSIXt"), "date" = "Date", # following types are not SQL types returned by dtypes(). They are listed here for usage # by checkType() in schema.R. diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 233a20c3d3866b68fccbf10400432e467956378f..1494ebb3de2544ddae69b77ff149860e390bdce3 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1306,9 +1306,9 @@ test_that("column functions", { # Test first(), last() df <- read.json(jsonPath) - expect_equal(collect(select(df, first(df$age)))[[1]], NA) + expect_equal(collect(select(df, first(df$age)))[[1]], NA_real_) expect_equal(collect(select(df, first(df$age, TRUE)))[[1]], 30) - expect_equal(collect(select(df, first("age")))[[1]], NA) + expect_equal(collect(select(df, first("age")))[[1]], NA_real_) expect_equal(collect(select(df, first("age", TRUE)))[[1]], 30) expect_equal(collect(select(df, last(df$age)))[[1]], 19) expect_equal(collect(select(df, last(df$age, TRUE)))[[1]], 19) @@ -2777,6 +2777,44 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume "Unnamed arguments ignored: 2, 3, a.") }) +test_that("Collect on DataFrame when NAs exists at the top of a timestamp column", { + ldf <- data.frame(col1 = c(0, 1, 2), + col2 = c(as.POSIXct("2017-01-01 00:00:01"), + NA, + as.POSIXct("2017-01-01 12:00:01")), + col3 = c(as.POSIXlt("2016-01-01 00:59:59"), + NA, + as.POSIXlt("2016-01-01 12:01:01"))) + sdf1 <- createDataFrame(ldf) + ldf1 <- collect(sdf1) + expect_equal(dtypes(sdf1), list(c("col1", "double"), + c("col2", "timestamp"), + c("col3", "timestamp"))) + expect_equal(class(ldf1$col1), "numeric") + expect_equal(class(ldf1$col2), c("POSIXct", "POSIXt")) + expect_equal(class(ldf1$col3), c("POSIXct", "POSIXt")) + + # Columns with NAs at the top + sdf2 <- filter(sdf1, "col1 > 1") + ldf2 <- collect(sdf2) + expect_equal(dtypes(sdf2), list(c("col1", "double"), + c("col2", "timestamp"), + c("col3", "timestamp"))) + expect_equal(class(ldf2$col1), "numeric") + expect_equal(class(ldf2$col2), c("POSIXct", "POSIXt")) + expect_equal(class(ldf2$col3), c("POSIXct", "POSIXt")) + + # Columns with only NAs, the type will also be cast to PRIMITIVE_TYPE + sdf3 <- filter(sdf1, "col1 == 0") + ldf3 <- collect(sdf3) + expect_equal(dtypes(sdf3), list(c("col1", "double"), + c("col2", "timestamp"), + c("col3", "timestamp"))) + expect_equal(class(ldf3$col1), "numeric") + expect_equal(class(ldf3$col2), c("POSIXct", "POSIXt")) + expect_equal(class(ldf3$col3), c("POSIXct", "POSIXt")) +}) + unlink(parquetPath) unlink(orcPath) unlink(jsonPath)