diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ed4ccdb4c8d4f304bd035635b651a3da43849853..b28ecb753f2263609d92e12c87a7dff7cf8e004c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2489,12 +2489,12 @@ class Dataset[T] private[sql]( val rdd: RDD[String] = queryExecution.toRdd.mapPartitions { iter => val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records - val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + val gen = new JacksonGenerator(rowSchema, writer) new Iterator[String] { override def hasNext: Boolean = iter.hasNext override def next(): String = { - JacksonGenerator(rowSchema, gen)(iter.next()) + gen.write(iter.next()) gen.flush() val json = writer.toString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 8b920ecafaeed291c4733b7b05834a8f21a31ffd..23f4a55491d288c1d6ae4a6f96c787759a74f5c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -17,74 +17,174 @@ package org.apache.spark.sql.execution.datasources.json +import java.io.Writer + import com.fasterxml.jackson.core._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData} import org.apache.spark.sql.types._ -private[sql] object JacksonGenerator { - /** Transforms a single InternalRow to JSON using Jackson - * - * TODO: make the code shared with the other apply method. +private[sql] class JacksonGenerator(schema: StructType, writer: Writer) { + // A `ValueWriter` is responsible for writing a field of an `InternalRow` to appropriate + // JSON data. Here we are using `SpecializedGetters` rather than `InternalRow` so that + // we can directly access data in `ArrayData` without the help of `SpecificMutableRow`. + private type ValueWriter = (SpecializedGetters, Int) => Unit + + // `ValueWriter`s for all fields of the schema + private val rootFieldWriters: Array[ValueWriter] = schema.map(_.dataType).map(makeWriter).toArray + + private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + + private def makeWriter(dataType: DataType): ValueWriter = dataType match { + case NullType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNull() + + case BooleanType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeBoolean(row.getBoolean(ordinal)) + + case ByteType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getByte(ordinal)) + + case ShortType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getShort(ordinal)) + + case IntegerType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getInt(ordinal)) + + case LongType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getLong(ordinal)) + + case FloatType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getFloat(ordinal)) + + case DoubleType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getDouble(ordinal)) + + case StringType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeString(row.getUTF8String(ordinal).toString) + + case TimestampType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeString(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)).toString) + + case DateType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeString(DateTimeUtils.toJavaDate(row.getInt(ordinal)).toString) + + case BinaryType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeBinary(row.getBinary(ordinal)) + + case dt: DecimalType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getDecimal(ordinal, dt.precision, dt.scale).toJavaBigDecimal) + + case st: StructType => + val fieldWriters = st.map(_.dataType).map(makeWriter) + (row: SpecializedGetters, ordinal: Int) => + writeObject(writeFields(row.getStruct(ordinal, st.length), st, fieldWriters)) + + case at: ArrayType => + val elementWriter = makeWriter(at.elementType) + (row: SpecializedGetters, ordinal: Int) => + writeArray(writeArrayData(row.getArray(ordinal), elementWriter)) + + case mt: MapType => + val valueWriter = makeWriter(mt.valueType) + (row: SpecializedGetters, ordinal: Int) => + writeObject(writeMapData(row.getMap(ordinal), mt, valueWriter)) + + // For UDT values, they should be in the SQL type's corresponding value type. + // We should not see values in the user-defined class at here. + // For example, VectorUDT's SQL type is an array of double. So, we should expect that v is + // an ArrayData at here, instead of a Vector. + case t: UserDefinedType[_] => + makeWriter(t.sqlType) + + case _ => + (row: SpecializedGetters, ordinal: Int) => + val v = row.get(ordinal, dataType) + sys.error(s"Failed to convert value $v (class of ${v.getClass}}) " + + s"with the type of $dataType to JSON.") + } + + private def writeObject(f: => Unit): Unit = { + gen.writeStartObject() + f + gen.writeEndObject() + } + + private def writeFields( + row: InternalRow, schema: StructType, fieldWriters: Seq[ValueWriter]): Unit = { + var i = 0 + while (i < row.numFields) { + val field = schema(i) + if (!row.isNullAt(i)) { + gen.writeFieldName(field.name) + fieldWriters(i).apply(row, i) + } + i += 1 + } + } + + private def writeArray(f: => Unit): Unit = { + gen.writeStartArray() + f + gen.writeEndArray() + } + + private def writeArrayData( + array: ArrayData, fieldWriter: ValueWriter): Unit = { + var i = 0 + while (i < array.numElements()) { + if (!array.isNullAt(i)) { + fieldWriter.apply(array, i) + } else { + gen.writeNull() + } + i += 1 + } + } + + private def writeMapData( + map: MapData, mapType: MapType, fieldWriter: ValueWriter): Unit = { + val keyArray = map.keyArray() + val valueArray = map.valueArray() + var i = 0 + while (i < map.numElements()) { + gen.writeFieldName(keyArray.get(i, mapType.keyType).toString) + if (!valueArray.isNullAt(i)) { + fieldWriter.apply(valueArray, i) + } else { + gen.writeNull() + } + i += 1 + } + } + + def close(): Unit = gen.close() + + def flush(): Unit = gen.flush() + + /** + * Transforms a single InternalRow to JSON using Jackson * - * @param rowSchema the schema object used for conversion - * @param gen a JsonGenerator object * @param row The row to convert */ - def apply(rowSchema: StructType, gen: JsonGenerator)(row: InternalRow): Unit = { - def valWriter: (DataType, Any) => Unit = { - case (_, null) | (NullType, _) => gen.writeNull() - case (StringType, v) => gen.writeString(v.toString) - case (TimestampType, v: Long) => gen.writeString(DateTimeUtils.toJavaTimestamp(v).toString) - case (IntegerType, v: Int) => gen.writeNumber(v) - case (ShortType, v: Short) => gen.writeNumber(v) - case (FloatType, v: Float) => gen.writeNumber(v) - case (DoubleType, v: Double) => gen.writeNumber(v) - case (LongType, v: Long) => gen.writeNumber(v) - case (DecimalType(), v: Decimal) => gen.writeNumber(v.toJavaBigDecimal) - case (ByteType, v: Byte) => gen.writeNumber(v.toInt) - case (BinaryType, v: Array[Byte]) => gen.writeBinary(v) - case (BooleanType, v: Boolean) => gen.writeBoolean(v) - case (DateType, v: Int) => gen.writeString(DateTimeUtils.toJavaDate(v).toString) - // For UDT values, they should be in the SQL type's corresponding value type. - // We should not see values in the user-defined class at here. - // For example, VectorUDT's SQL type is an array of double. So, we should expect that v is - // an ArrayData at here, instead of a Vector. - case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v) - - case (ArrayType(ty, _), v: ArrayData) => - gen.writeStartArray() - v.foreach(ty, (_, value) => valWriter(ty, value)) - gen.writeEndArray() - - case (MapType(kt, vt, _), v: MapData) => - gen.writeStartObject() - v.foreach(kt, vt, { (k, v) => - gen.writeFieldName(k.toString) - valWriter(vt, v) - }) - gen.writeEndObject() - - case (StructType(ty), v: InternalRow) => - gen.writeStartObject() - var i = 0 - while (i < ty.length) { - val field = ty(i) - val value = v.get(i, field.dataType) - if (value != null) { - gen.writeFieldName(field.name) - valWriter(field.dataType, value) - } - i += 1 - } - gen.writeEndObject() - - case (dt, v) => - sys.error( - s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.") + def write(row: InternalRow): Unit = { + writeObject { + writeFields(row, schema, rootFieldWriters) } - - valWriter(rowSchema, row) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 86aef1f7d4411c9d1bc9c54ef539ea21bdcf7a07..adca8d7af0bd8e6955a66ca6579dd337a524d330 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.json import java.io.CharArrayWriter -import com.fasterxml.jackson.core.JsonFactory import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, NullWritable, Text} @@ -162,7 +161,7 @@ private[json] class JsonOutputWriter( private[this] val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records - private[this] val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + private[this] val gen = new JacksonGenerator(dataSchema, writer) private[this] val result = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { @@ -181,7 +180,7 @@ private[json] class JsonOutputWriter( override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") override protected[sql] def writeInternal(row: InternalRow): Unit = { - JacksonGenerator(dataSchema, gen)(row) + gen.write(row) gen.flush() result.set(writer.toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6c72019702c3d51c25e0349ffe26d377db7485f9..a09f61aba9d391d6da68862f1af8d1df8ff35f0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -21,10 +21,7 @@ import java.io.{File, StringWriter} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} -import scala.collection.JavaConverters._ - import com.fasterxml.jackson.core.JsonFactory -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec