diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 5bb6f6c85d8012a4114048cf457f99dae1c50206..0f2dcdcacf0cad74797d578d6a7358e7cbb9ab72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -73,16 +73,18 @@ private[sql] object JsonRDD extends Logging {
     def makeStruct(values: Seq[Seq[String]], prefix: Seq[String]): StructType = {
       val (topLevel, structLike) = values.partition(_.size == 1)
+
       val topLevelFields = topLevel.filter {
         name => resolved.get(prefix ++ name).get match {
           case ArrayType(elementType, _) => {
             def hasInnerStruct(t: DataType): Boolean = t match {
-              case s: StructType => false
+              case s: StructType => true
               case ArrayType(t1, _) => hasInnerStruct(t1)
-              case o => true
+              case o => false
             }
-            hasInnerStruct(elementType)
+            // Check if this array has inner struct.
+            !hasInnerStruct(elementType)
           }
           case struct: StructType => false
           case _ => true
@@ -90,8 +92,11 @@ private[sql] object JsonRDD extends Logging {
       }.map {
         a => StructField(a.head, resolved.get(prefix ++ a).get, nullable = true)
       }
+      val topLevelFieldNameSet = topLevelFields.map(_.name)
 
-      val structFields: Seq[StructField] = structLike.groupBy(_(0)).map {
+      val structFields: Seq[StructField] = structLike.groupBy(_(0)).filter {
+        case (name, _) => !topLevelFieldNameSet.contains(name)
+      }.map {
         case (name, fields) => {
           val nestedFields = fields.map(_.tail)
           val structType = makeStruct(nestedFields, prefix :+ name)
@@ -354,7 +359,8 @@ private[sql] object JsonRDD extends Logging {
         case (key, value) =>
           if (count > 0) builder.append(",")
           count += 1
-          builder.append(s"""\"${key}\":${toString(value)}""")
+          val stringValue = if (value.isInstanceOf[String]) s"""\"$value\"""" else toString(value)
+          builder.append(s"""\"${key}\":${stringValue}""")
       }
       builder.append("}")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 362c7e1a52482d8c0093caa3e7964940a65f3b21..4b851d1b9615266353158dee7775dc185f8395a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -483,7 +483,8 @@ class JsonSuite extends QueryTest {
     val expectedSchema = StructType(
       StructField("array1", ArrayType(StringType, true), true) ::
       StructField("array2", ArrayType(StructType(
-        StructField("field", LongType, true) :: Nil), false), true) :: Nil)
+        StructField("field", LongType, true) :: Nil), false), true) ::
+      StructField("array3", ArrayType(StringType, false), true) :: Nil)
 
     assert(expectedSchema === jsonSchemaRDD.schema)
 
@@ -492,12 +493,14 @@
     checkAnswer(
       sql("select * from jsonTable"),
       Seq(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]",
-        """{"field":str}"""), Seq(Seq(214748364700L), Seq(1))) :: Nil
+        """{"field":"str"}"""), Seq(Seq(214748364700L), Seq(1)), null) ::
+      Seq(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) ::
+      Seq(null, null, Seq("1", "2", "3")) :: Nil
     )
 
     // Treat an element as a number.
     checkAnswer(
-      sql("select array1[0] + 1 from jsonTable"),
+      sql("select array1[0] + 1 from jsonTable where array1 is not null"),
       2
     )
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
index c204162dd2fc1f1be5a1b1a232c3c055a28f02d5..e5773a55875bc35e41fa01b9b3f25f8528a02e1b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
@@ -57,7 +57,9 @@ object TestJsonData {
   val arrayElementTypeConflict =
     TestSQLContext.sparkContext.parallelize(
       """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
-          "array2": [{"field":214748364700}, {"field":1}]}""" :: Nil)
+          "array2": [{"field":214748364700}, {"field":1}]}""" ::
+      """{"array3": [{"field":"str"}, {"field":1}]}""" ::
+      """{"array3": [1, 2, 3]}""" :: Nil)
 
   val missingFields =
     TestSQLContext.sparkContext.parallelize(
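
The gist of the first JsonRDD hunk: hasInnerStruct previously returned false on StructType, so an array whose elements (at any nesting depth) were structs was kept as a plain top-level field and its struct schema was never built. Below is a minimal standalone sketch of the corrected recursion, not the patched makeStruct itself; it assumes the pre-1.3 type package org.apache.spark.sql.catalyst.types, and the object name is hypothetical.

    import org.apache.spark.sql.catalyst.types._

    // Sketch: mirrors the corrected helper inside makeStruct.
    object HasInnerStructSketch extends App {
      // True iff the type contains a StructType at any nesting depth.
      def hasInnerStruct(t: DataType): Boolean = t match {
        case s: StructType => true                  // struct found
        case ArrayType(t1, _) => hasInnerStruct(t1) // recurse into nested arrays
        case o => false                             // primitive leaf
      }

      // Like "array2"/"array3": the element type carries a struct, so the
      // field is excluded from the top-level (string-valued) fields and is
      // resolved by the struct-building branch instead.
      val structElement = StructType(StructField("field", LongType, true) :: Nil)
      assert(hasInnerStruct(ArrayType(structElement, false).elementType))

      // Nested arrays of primitives stay top-level, as before the patch.
      assert(!hasInnerStruct(ArrayType(ArrayType(StringType, false), false).elementType))
    }

The follow-up filter on structLike.groupBy(_(0)) then drops any struct-like key that already resolved to a top-level field; that is what prevents a duplicate "array3" field when the two conflicting "array3" records force the element type down to StringType.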
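The second JsonRDD hunk fixes how a conflicting array element is rendered once its type falls back to string: string values used to be emitted bare ({"field":str}), which is not parseable JSON, hence the updated expected answers in JsonSuite. The sketch below approximates that rendering with a hypothetical helper; the real code calls the private recursive toString so nested maps and sequences render too, while here plain value.toString stands in for the non-string branch.

    // Sketch of the quoting fix for string-typed values.
    object MapToJsonStringSketch extends App {
      def render(map: Map[String, Any]): String = {
        val builder = new StringBuilder("{")
        var count = 0
        map.foreach {
          case (key, value) =>
            if (count > 0) builder.append(",")
            count += 1
            // Strings are wrapped in quotes; other values are rendered as-is.
            val stringValue =
              if (value.isInstanceOf[String]) s"""\"$value\"""" else value.toString
            builder.append(s"""\"${key}\":${stringValue}""")
        }
        builder.append("}").toString()
      }

      println(render(Map("field" -> "str"))) // {"field":"str"} (was {"field":str})
      println(render(Map("field" -> 1)))     // {"field":1}
    }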