diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index e5d62a466cab6909da7ac6ddf4571bcb1fd91236..abb284d1e3dd9470277365ca07e7423aa0f55149 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -45,7 +45,7 @@ from py4j.java_collections import ListConverter, MapConverter
 
 from pyspark.rdd import RDD
 from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
-    CloudPickleSerializer
+    CloudPickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
 
@@ -1870,6 +1870,21 @@ class SchemaRDD(RDD):
         rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD()
         return SchemaRDD(rdd, self.sql_ctx)
 
+    def toJSON(self, use_unicode=False):
+        """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row.
+
+        >>> srdd1 = sqlCtx.jsonRDD(json)
+        >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
+        >>> srdd2 = sqlCtx.sql("SELECT * from table1")
+        >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}'
+        True
+        >>> srdd3 = sqlCtx.sql("SELECT field3.field4 from table1")
+        >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}']
+        True
+        """
+        rdd = self._jschema_rdd.baseSchemaRDD().toJSON()
+        return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
+
     def saveAsParquetFile(self, path):
         """Save the contents as a Parquet file, preserving the schema.
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 904a276ef3ffba8ae28cac4381ee80f2ef4e95f1..f8970cd3e6360a6d5c9bd2e20e281069a7027dcf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -17,17 +17,19 @@
 
 package org.apache.spark.sql
 
-import java.util.{List => JList}
-
-import org.apache.spark.api.python.SerDeUtil
+import java.util.{Map => JMap, List => JList}
+import java.io.StringWriter
 
 import scala.collection.JavaConversions._
 
+import com.fasterxml.jackson.core.JsonFactory
+
 import net.razorvine.pickle.Pickler
 
 import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext}
 import org.apache.spark.annotation.{AlphaComponent, Experimental}
 import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.api.java.JavaSchemaRDD
 import org.apache.spark.sql.catalyst.ScalaReflection
@@ -35,6 +37,7 @@
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.json.JsonRDD
 import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
 import org.apache.spark.storage.StorageLevel
@@ -131,6 +134,20 @@
    */
   lazy val schema: StructType = queryExecution.analyzed.schema
 
+  /**
+   * Returns a new RDD with each row transformed to a JSON string.
+   *
+   * @group schema
+   */
+  def toJSON: RDD[String] = {
+    val rowSchema = this.schema
+    this.mapPartitions { iter =>
+      val jsonFactory = new JsonFactory()
+      iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory))
+    }
+  }
+
+
   // =======================================================================
   // Query DSL
   // =======================================================================
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
index 78e8d908fe0c8ef2bc0a558113f6283c0ee15d83..ac4844f9b929035e322c17b1d0d09433b49513dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -125,6 +125,12 @@ class JavaSchemaRDD(
 
   // Transformations (return a new RDD)
 
+  /**
+   * Returns a new RDD with each row transformed to a JSON string.
+   */
+  def toJSON(): JavaRDD[String] =
+    baseSchemaRDD.toJSON.toJavaRDD
+
   /**
    * Return a new RDD that is reduced into `numPartitions` partitions.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index d9d7a3fea3963718964e932c273310936097fad6..ffb9548356d1dcb377ad616c4a1093fc189697f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -20,12 +20,15 @@ package org.apache.spark.sql.json
 import org.apache.spark.sql.catalyst.types.decimal.Decimal
 import org.apache.spark.sql.types.util.DataTypeConversions
 
+import java.io.StringWriter
+
 import scala.collection.Map
 import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
 import scala.math.BigDecimal
 import java.sql.{Date, Timestamp}
 
 import com.fasterxml.jackson.core.JsonProcessingException
+import com.fasterxml.jackson.core.JsonFactory
 import com.fasterxml.jackson.databind.ObjectMapper
 
 import org.apache.spark.rdd.RDD
@@ -424,4 +427,61 @@
     row
   }
 
+
+  /** Transforms a single Row to JSON using Jackson.
+    *
+    * @param jsonFactory a JsonFactory object to construct a JsonGenerator
+    * @param rowSchema the schema object used for conversion
+    * @param row the row to convert
+    */
+  private[sql] def rowToJSON(rowSchema: StructType, jsonFactory: JsonFactory)(row: Row): String = {
+    val writer = new StringWriter()
+    val gen = jsonFactory.createGenerator(writer)
+
+    def valWriter: (DataType, Any) => Unit = {
+      case (_, null) | (NullType, _) => gen.writeNull()
+      case (StringType, v: String) => gen.writeString(v)
+      case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
+      case (IntegerType, v: Int) => gen.writeNumber(v)
+      case (ShortType, v: Short) => gen.writeNumber(v)
+      case (FloatType, v: Float) => gen.writeNumber(v)
+      case (DoubleType, v: Double) => gen.writeNumber(v)
+      case (LongType, v: Long) => gen.writeNumber(v)
+      case (DecimalType(), v: scala.math.BigDecimal) => gen.writeNumber(v.bigDecimal)
+      case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v)
+      case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
+      case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
+      case (BooleanType, v: Boolean) => gen.writeBoolean(v)
+      case (DateType, v) => gen.writeString(v.toString)
+      case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v)
+
+      case (ArrayType(ty, _), v: Seq[_]) =>
+        gen.writeStartArray()
+        v.foreach(valWriter(ty, _))
+        gen.writeEndArray()
+
+      case (MapType(kv, vv, _), v: Map[_, _]) =>
+        gen.writeStartObject()
+        v.foreach { p =>
+          gen.writeFieldName(p._1.toString)
+          valWriter(vv, p._2)
+        }
+        gen.writeEndObject()
+
+      case (StructType(ty), v: Seq[_]) =>
+        gen.writeStartObject()
+        ty.zip(v).foreach {
+          case (_, null) =>
+          case (field, v) =>
+            gen.writeFieldName(field.name)
+            valWriter(field.dataType, v)
+        }
+        gen.writeEndObject()
+    }
+
+    valWriter(rowSchema, row)
+    gen.close()
+    writer.toString
+  }
+
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index f8ca2c773d9ab527a0fa70fff43087f8db8a86d3..f088d413257a929cc36246f57eaea641df693e9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -22,6 +22,7 @@
 import org.apache.spark.sql.catalyst.types.decimal.Decimal
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
 import org.apache.spark.sql.{Row, SQLConf, QueryTest}
+import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext._
@@ -779,4 +780,125 @@
       Seq(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil
     )
   }
+
+  test("SPARK-4228 SchemaRDD to JSON")
+  {
+    val schema1 = StructType(
+      StructField("f1", IntegerType, false) ::
+      StructField("f2", StringType, false) ::
+      StructField("f3", BooleanType, false) ::
+      StructField("f4", ArrayType(StringType), nullable = true) ::
+      StructField("f5", IntegerType, true) :: Nil)
+
+    val rowRDD1 = unparsedStrings.map { r =>
+      val values = r.split(",").map(_.trim)
+      val v5 = try values(3).toInt catch {
+        case _: NumberFormatException => null
+      }
+      Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5)
+    }
+
+    val schemaRDD1 = applySchema(rowRDD1, schema1)
+    schemaRDD1.registerTempTable("applySchema1")
+    val schemaRDD2 = schemaRDD1.toSchemaRDD
+    val result = schemaRDD2.toJSON.collect()
+    assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}")
+    assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}")
+
+    val schema2 = StructType(
+      StructField("f1", StructType(
+        StructField("f11", IntegerType, false) ::
+        StructField("f12", BooleanType, false) :: Nil), false) ::
+      StructField("f2", MapType(StringType, IntegerType, true), false) :: Nil)
+
+    val rowRDD2 = unparsedStrings.map { r =>
+      val values = r.split(",").map(_.trim)
+      val v4 = try values(3).toInt catch {
+        case _: NumberFormatException => null
+      }
+      Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
+    }
+
+    val schemaRDD3 = applySchema(rowRDD2, schema2)
+    schemaRDD3.registerTempTable("applySchema2")
+    val schemaRDD4 = schemaRDD3.toSchemaRDD
+    val result2 = schemaRDD4.toJSON.collect()
+
+    assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
+    assert(result2(3) == "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}")
+
+    val jsonSchemaRDD = jsonRDD(primitiveFieldAndType)
+    val primTable = jsonRDD(jsonSchemaRDD.toJSON)
+    primTable.registerTempTable("primativeTable")
+    checkAnswer(
+      sql("select * from primativeTable"),
+      (BigDecimal("92233720368547758070"),
+      true,
+      1.7976931348623157E308,
+      10,
+      21474836470L,
+      "this is a simple string.") :: Nil
+    )
+
+    val complexJsonSchemaRDD = jsonRDD(complexFieldAndType1)
+    val compTable = jsonRDD(complexJsonSchemaRDD.toJSON)
+    compTable.registerTempTable("complexTable")
+    // Access elements of a primitive array.
+    checkAnswer(
+      sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"),
+      ("str1", "str2", null) :: Nil
+    )
+
+    // Access an array of null values.
+    checkAnswer(
+      sql("select arrayOfNull from complexTable"),
+      Seq(Seq(null, null, null, null)) :: Nil
+    )
+
+    // Access elements of a BigInteger array (we use DecimalType internally).
+    checkAnswer(
+      sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"),
+      (BigDecimal("922337203685477580700"), BigDecimal("-922337203685477580800"), null) :: Nil
+    )
+
+    // Access elements of an array of arrays.
+    checkAnswer(
+      sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"),
+      (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil
+    )
+
+    // Access elements of an array of arrays.
+    checkAnswer(
+      sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"),
+      (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil
+    )
+
+    // Access elements of an array inside a field with the type of ArrayType(ArrayType).
+    checkAnswer(
+      sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"),
+      ("str2", 2.1) :: Nil
+    )
+
+    // Access a struct and fields inside of it.
+    checkAnswer(
+      sql("select struct, struct.field1, struct.field2 from complexTable"),
+      Row(
+        Row(true, BigDecimal("92233720368547758070")),
+        true,
+        BigDecimal("92233720368547758070")) :: Nil
+    )
+
+    // Access an array field of a struct.
+    checkAnswer(
+      sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"),
+      (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil
+    )
+
+    // Access elements of an array field of a struct.
+    checkAnswer(
+      sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"),
+      (5, null) :: Nil
+    )
+
+  }
 }
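Usage note (not part of the patch): a minimal sketch of how the new toJSON API could be exercised from Scala once this change is applied. The SQLContext instance, the input path, and the object and method names below are illustrative assumptions rather than code from this patch; jsonFile and jsonRDD are the existing SQLContext entry points for JSON data.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, SchemaRDD}

// Hypothetical round-trip: SchemaRDD -> RDD[String] of JSON -> SchemaRDD.
object ToJSONRoundTrip {
  def roundTrip(sqlContext: SQLContext, path: String): Unit = {
    val people: SchemaRDD = sqlContext.jsonFile(path)     // infer a schema from a JSON file
    val asJson: RDD[String] = people.toJSON               // one JSON document per row (added by this patch)
    val reparsed: SchemaRDD = sqlContext.jsonRDD(asJson)  // parse the JSON strings back into a SchemaRDD
    reparsed.printSchema()
  }
}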