Commit 82464fb2 authored by Stephen De Gennaro, committed by Yin Huai

[SPARK-10947] [SQL] With schema inference from JSON into a Dataframe, add option to infer all primitive object types as strings

Currently, when a schema is inferred from a JSON file using sqlContext.read.json, the primitive object types are inferred as string, long, boolean, etc.

However, if the inferred type is too specific (JSON itself does not enforce types), this can cause issues when merging DataFrame schemas.

This pull request adds the option "primitivesAsString" to the JSON DataFrameReader; when set to true (it defaults to false), all primitive values are inferred as strings.

Below is an example usage of this new functionality.
```
val jsonDf = sqlContext.read.option("primitivesAsString", "true").json(sampleJsonFile)

scala> jsonDf.printSchema()
root
|-- bigInteger: string (nullable = true)
|-- boolean: string (nullable = true)
|-- double: string (nullable = true)
|-- integer: string (nullable = true)
|-- long: string (nullable = true)
|-- null: string (nullable = true)
|-- string: string (nullable = true)
```
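
For contrast, here is a minimal sketch of the schema-merging scenario described above. The file names "events1.json" and "events2.json" and their contents are hypothetical; they are assumed to share a field whose values would otherwise be inferred with different primitive types (e.g. long in one file, double in the other).
```
// Default inference: the same field can come back with different primitive
// types in different files, so the two DataFrame schemas do not match.
val df1 = sqlContext.read.json("events1.json")
val df2 = sqlContext.read.json("events2.json")

// With primitivesAsString, every primitive is inferred as StringType, so the
// schemas line up and the DataFrames can be combined.
val s1 = sqlContext.read.option("primitivesAsString", "true").json("events1.json")
val s2 = sqlContext.read.option("primitivesAsString", "true").json("events2.json")
val merged = s1.unionAll(s2)
```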

Author: Stephen De Gennaro <stepheng@realitymine.com>

Closes #9249 from stephend-realitymine/stephend-primitives.
parent d4c397a6
@@ -256,8 +256,16 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*/
def json(jsonRDD: RDD[String]): DataFrame = {
val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble
val primitivesAsString = extraOptions.getOrElse("primitivesAsString", "false").toBoolean
sqlContext.baseRelationToDataFrame(
new JSONRelation(Some(jsonRDD), samplingRatio, userSpecifiedSchema, None, None)(sqlContext))
new JSONRelation(
Some(jsonRDD),
samplingRatio,
primitivesAsString,
userSpecifiedSchema,
None,
None)(sqlContext)
)
}
/**
......
@@ -35,7 +35,8 @@ private[sql] object InferSchema {
def apply(
json: RDD[String],
samplingRatio: Double = 1.0,
columnNameOfCorruptRecords: String): StructType = {
columnNameOfCorruptRecords: String,
primitivesAsString: Boolean = false): StructType = {
require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0")
val schemaData = if (samplingRatio > 0.99) {
json
@@ -50,7 +51,7 @@ private[sql] object InferSchema {
try {
Utils.tryWithResource(factory.createParser(row)) { parser =>
parser.nextToken()
inferField(parser)
inferField(parser, primitivesAsString)
}
} catch {
case _: JsonParseException =>
@@ -70,14 +71,14 @@ private[sql] object InferSchema {
/**
* Infer the type of a json document from the parser's token stream
*/
private def inferField(parser: JsonParser): DataType = {
private def inferField(parser: JsonParser, primitivesAsString: Boolean): DataType = {
import com.fasterxml.jackson.core.JsonToken._
parser.getCurrentToken match {
case null | VALUE_NULL => NullType
case FIELD_NAME =>
parser.nextToken()
inferField(parser)
inferField(parser, primitivesAsString)
case VALUE_STRING if parser.getTextLength < 1 =>
// Zero length strings and nulls have special handling to deal
@@ -92,7 +93,10 @@ private[sql] object InferSchema {
case START_OBJECT =>
val builder = Seq.newBuilder[StructField]
while (nextUntil(parser, END_OBJECT)) {
builder += StructField(parser.getCurrentName, inferField(parser), nullable = true)
builder += StructField(
parser.getCurrentName,
inferField(parser, primitivesAsString),
nullable = true)
}
StructType(builder.result().sortBy(_.name))
@@ -103,11 +107,15 @@ private[sql] object InferSchema {
// the type as we pass through all JSON objects.
var elementType: DataType = NullType
while (nextUntil(parser, END_ARRAY)) {
elementType = compatibleType(elementType, inferField(parser))
elementType = compatibleType(elementType, inferField(parser, primitivesAsString))
}
ArrayType(elementType)
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if primitivesAsString => StringType
case (VALUE_TRUE | VALUE_FALSE) if primitivesAsString => StringType
case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
import JsonParser.NumberType._
parser.getNumberType match {
......
@@ -52,14 +52,23 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
partitionColumns: Option[StructType],
parameters: Map[String, String]): HadoopFsRelation = {
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
val primitivesAsString = parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false)
new JSONRelation(None, samplingRatio, dataSchema, None, partitionColumns, paths)(sqlContext)
new JSONRelation(
None,
samplingRatio,
primitivesAsString,
dataSchema,
None,
partitionColumns,
paths)(sqlContext)
}
}
private[sql] class JSONRelation(
val inputRDD: Option[RDD[String]],
val samplingRatio: Double,
val primitivesAsString: Boolean,
val maybeDataSchema: Option[StructType],
val maybePartitionSpec: Option[PartitionSpec],
override val userDefinedPartitionColumns: Option[StructType],
@@ -105,7 +114,8 @@ private[sql] class JSONRelation(
InferSchema(
inputRDD.getOrElse(createBaseRdd(files)),
samplingRatio,
sqlContext.conf.columnNameOfCorruptRecord)
sqlContext.conf.columnNameOfCorruptRecord,
primitivesAsString)
}
checkConstraints(jsonSchema)
......
@@ -632,6 +632,136 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
}
test("Loading a JSON dataset primitivesAsString returns schema with primitive types as strings") {
val dir = Utils.createTempDir()
dir.delete()
val path = dir.getCanonicalPath
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(path)
val expectedSchema = StructType(
StructField("bigInteger", StringType, true) ::
StructField("boolean", StringType, true) ::
StructField("double", StringType, true) ::
StructField("integer", StringType, true) ::
StructField("long", StringType, true) ::
StructField("null", StringType, true) ::
StructField("string", StringType, true) :: Nil)
assert(expectedSchema === jsonDF.schema)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
sql("select * from jsonTable"),
Row("92233720368547758070",
"true",
"1.7976931348623157E308",
"10",
"21474836470",
null,
"this is a simple string.")
)
}
test("Loading a JSON dataset primitivesAsString returns complex fields as strings") {
val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(complexFieldAndType1)
val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) ::
StructField("arrayOfArray2", ArrayType(ArrayType(StringType, true), true), true) ::
StructField("arrayOfBigInteger", ArrayType(StringType, true), true) ::
StructField("arrayOfBoolean", ArrayType(StringType, true), true) ::
StructField("arrayOfDouble", ArrayType(StringType, true), true) ::
StructField("arrayOfInteger", ArrayType(StringType, true), true) ::
StructField("arrayOfLong", ArrayType(StringType, true), true) ::
StructField("arrayOfNull", ArrayType(StringType, true), true) ::
StructField("arrayOfString", ArrayType(StringType, true), true) ::
StructField("arrayOfStruct", ArrayType(
StructType(
StructField("field1", StringType, true) ::
StructField("field2", StringType, true) ::
StructField("field3", StringType, true) :: Nil), true), true) ::
StructField("struct", StructType(
StructField("field1", StringType, true) ::
StructField("field2", StringType, true) :: Nil), true) ::
StructField("structWithArrayFields", StructType(
StructField("field1", ArrayType(StringType, true), true) ::
StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil)
assert(expectedSchema === jsonDF.schema)
jsonDF.registerTempTable("jsonTable")
// Access elements of a primitive array.
checkAnswer(
sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"),
Row("str1", "str2", null)
)
// Access an array of null values.
checkAnswer(
sql("select arrayOfNull from jsonTable"),
Row(Seq(null, null, null, null))
)
// Access elements of a BigInteger array (we use DecimalType internally).
checkAnswer(
sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"),
Row("922337203685477580700", "-922337203685477580800", null)
)
// Access elements of an array of arrays.
checkAnswer(
sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"),
Row(Seq("1", "2", "3"), Seq("str1", "str2"))
)
// Access elements of an array of arrays.
checkAnswer(
sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"),
Row(Seq("1", "2", "3"), Seq("1.1", "2.1", "3.1"))
)
// Access elements of an array inside a field with the type ArrayType(ArrayType).
checkAnswer(
sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"),
Row("str2", "2.1")
)
// Access elements of an array of structs.
checkAnswer(
sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " +
"from jsonTable"),
Row(
Row("true", "str1", null),
Row("false", null, null),
Row(null, null, null),
null)
)
// Access a struct and fields inside of it.
checkAnswer(
sql("select struct, struct.field1, struct.field2 from jsonTable"),
Row(
Row("true", "92233720368547758070"),
"true",
"92233720368547758070") :: Nil
)
// Access an array field of a struct.
checkAnswer(
sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"),
Row(Seq("4", "5", "6"), Seq("str1", "str2"))
)
// Access elements of an array field of a struct.
checkAnswer(
sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"),
Row("5", null)
)
}
test("Loading a JSON dataset from a text file with SQL") {
val dir = Utils.createTempDir()
dir.delete()
@@ -960,9 +1090,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val jsonDF = sqlContext.read.json(primitiveFieldAndType)
val primTable = sqlContext.read.json(jsonDF.toJSON)
primTable.registerTempTable("primativeTable")
primTable.registerTempTable("primitiveTable")
checkAnswer(
sql("select * from primativeTable"),
sql("select * from primitiveTable"),
Row(new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
@@ -1039,24 +1169,28 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val relation0 = new JSONRelation(
Some(empty),
1.0,
false,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
None, None)(sqlContext)
val logicalRelation0 = LogicalRelation(relation0)
val relation1 = new JSONRelation(
Some(singleRow),
1.0,
false,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
None, None)(sqlContext)
val logicalRelation1 = LogicalRelation(relation1)
val relation2 = new JSONRelation(
Some(singleRow),
0.5,
false,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
None, None)(sqlContext)
val logicalRelation2 = LogicalRelation(relation2)
val relation3 = new JSONRelation(
Some(singleRow),
1.0,
false,
Some(StructType(StructField("b", IntegerType, true) :: Nil)),
None, None)(sqlContext)
val logicalRelation3 = LogicalRelation(relation3)
......