diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 88e1a508f37c43094eceafbaad39ad9e104ded15..22a4b5bf86ebd98c10f03bda3e01f0a870876fb3 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -452,7 +452,7 @@ dropTempTable <- function(sqlContext, tableName) { #' df <- read.df(sqlContext, "path/to/file.json", source = "json") #' } -read.df <- function(sqlContext, path = NULL, source = NULL, ...) { +read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { options[['path']] <- path @@ -462,15 +462,21 @@ read.df <- function(sqlContext, path = NULL, source = NULL, ...) { source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } - sdf <- callJMethod(sqlContext, "load", source, options) + if (!is.null(schema)) { + stopifnot(class(schema) == "structType") + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, + schema$jobj, options) + } else { + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options) + } dataFrame(sdf) } #' @aliases loadDF #' @export -loadDF <- function(sqlContext, path = NULL, source = NULL, ...) { - read.df(sqlContext, path, source, ...) +loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { + read.df(sqlContext, path, source, schema, ...) } #' Create an external table diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d2d82e791e8763bbef936773e33767aa8921ab5f..30edfc8a7bd9422d2934edff5a7b5e29cc3c41e8 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -504,6 +504,19 @@ test_that("read.df() from json file", { df <- read.df(sqlContext, jsonPath, "json") expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 3) + + # Check if we can apply a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(sqlContext, jsonPath, "json", schema) + expect_true(inherits(df1, "DataFrame")) + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Run the same with loadDF + df2 <- loadDF(sqlContext, jsonPath, "json", schema) + expect_true(inherits(df2, "DataFrame")) + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) }) test_that("write.df() as parquet file", { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 604f3124e23ae58b518b241498251e8e6874f7ad..43b62f0e822f859d7a92a7c1b851e4f588e041f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -139,4 +139,19 @@ private[r] object SQLUtils { case "ignore" => SaveMode.Ignore } } + + def loadDF( + sqlContext: SQLContext, + source: String, + options: java.util.Map[String, String]): DataFrame = { + sqlContext.read.format(source).options(options).load() + } + + def loadDF( + sqlContext: SQLContext, + source: String, + schema: StructType, + options: java.util.Map[String, String]): DataFrame = { + sqlContext.read.format(source).schema(schema).options(options).load() + } }