diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index 6cf628e3007de9ed661b638b00215aaad0a18a77..88f18613fd7b1461aac96566cefa672261dd7e5f 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -57,8 +57,10 @@ readTypedObject <- function(con, type) {
 
 readString <- function(con) {
   stringLen <- readInt(con)
-  string <- readBin(con, raw(), stringLen, endian = "big")
-  rawToChar(string)
+  raw <- readBin(con, raw(), stringLen, endian = "big")
+  string <- rawToChar(raw)
+  Encoding(string) <- "UTF-8"
+  string
 }
 
 readInt <- function(con) {
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index e3676f57f907f4c96202ca7a643d287ae850d142..91e6b3e5609b5bc54235d28d133a6c8cee7becf3 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -79,7 +79,7 @@ writeJobj <- function(con, value) {
 writeString <- function(con, value) {
   utfVal <- enc2utf8(value)
   writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1))
-  writeBin(utfVal, con, endian = "big")
+  writeBin(utfVal, con, endian = "big", useBytes=TRUE)
 }
 
 writeInt <- function(con, value) {
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 0da5e386547327dc8084c21f016cfe8e33bdb954..6d331f9883d5537bccfdffa5d6e77ceb9b858020 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -431,6 +431,32 @@ test_that("collect() and take() on a DataFrame return the same number of rows an
   expect_equal(ncol(collect(df)), ncol(take(df, 10)))
 })
 
+test_that("collect() support Unicode characters", {
+  markUtf8 <- function(s) {
+    Encoding(s) <- "UTF-8"
+    s
+  }
+
+  lines <- c("{\"name\":\"안녕하세요\"}",
+             "{\"name\":\"您好\", \"age\":30}",
+             "{\"name\":\"こんにちは\", \"age\":19}",
+             "{\"name\":\"Xin chào\"}")
+
+  jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
+  writeLines(lines, jsonPath)
+
+  df <- read.df(sqlContext, jsonPath, "json")
+  rdf <- collect(df)
+  expect_true(is.data.frame(rdf))
+  expect_equal(rdf$name[1], markUtf8("안녕하세요"))
+  expect_equal(rdf$name[2], markUtf8("您好"))
+  expect_equal(rdf$name[3], markUtf8("こんにちは"))
+  expect_equal(rdf$name[4], markUtf8("Xin chào"))
+
+  df1 <- createDataFrame(sqlContext, rdf)
+  expect_equal(collect(where(df1, df1$name == markUtf8("您好")))$name, markUtf8("您好"))
+})
+
 test_that("multiple pipeline transformations result in an RDD with the correct values", {
   df <- jsonFile(sqlContext, jsonPath)
   first <- lapply(df, function(row) {
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 26ad4f1d4697ed26cdad10316ba4a805ee9f6b02..190e193427af8fb598a3b78d1cb46a684ef829d4 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -329,12 +329,11 @@ private[spark] object SerDe {
     out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9)
   }
 
-  // NOTE: Only works for ASCII right now
   def writeString(out: DataOutputStream, value: String): Unit = {
-    val len = value.length
-    out.writeInt(len + 1) // For the \0
-    out.writeBytes(value)
-    out.writeByte(0)
+    val utf8 = value.getBytes("UTF-8")
+    val len = utf8.length
+    out.writeInt(len)
+    out.write(utf8, 0, len)
  }
 
  def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = {
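
The core of the fix is in SerDe.scala: the old writeString used value.length, which counts UTF-16 code units rather than bytes, and DataOutputStream.writeBytes, which discards the high byte of every char, so any non-ASCII string was corrupted on the wire. The new format is a 4-byte big-endian length prefix followed by the raw UTF-8 bytes, with no trailing NUL. Below is a minimal, self-contained Scala sketch of that wire format round-trip; it is not part of the patch, and StringSerDeSketch and its main method are illustrative only.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object StringSerDeSketch {
  // Mirrors the patched SerDe.writeString: length-prefixed raw UTF-8 bytes.
  def writeString(out: DataOutputStream, value: String): Unit = {
    val utf8 = value.getBytes("UTF-8")
    out.writeInt(utf8.length) // byte length, not value.length (UTF-16 code units)
    out.write(utf8, 0, utf8.length)
  }

  // The matching reader: consume exactly the advertised number of bytes.
  def readString(in: DataInputStream): String = {
    val len = in.readInt()
    val buf = new Array[Byte](len)
    in.readFully(buf)
    new String(buf, "UTF-8")
  }

  def main(args: Array[String]): Unit = {
    val s = "您好" // 2 chars but 6 UTF-8 bytes
    val bytes = new ByteArrayOutputStream()
    writeString(new DataOutputStream(bytes), s)
    val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
    assert(readString(in) == s)
    // The removed code would have written a prefix of s.length + 1 = 3 and then
    // only the low byte of each UTF-16 code unit plus a trailing \0, so the
    // receiver saw 3 garbage bytes instead of the 6 bytes of "您好".
    println(s"chars=${s.length}, utf8 bytes=${s.getBytes("UTF-8").length}")
  }
}

The R side of the patch is symmetric: readString now consumes exactly stringLen bytes and marks the result with Encoding(string) <- "UTF-8", so R interprets the bytes as UTF-8 instead of the native locale encoding, and writeString passes useBytes=TRUE so writeBin emits the enc2utf8 bytes without re-encoding them through the locale.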