diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index d87c4945c81948817b6ff7639906bd321abde005..eeabfdd85791619356527164714d589be03f31e4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -346,16 +346,20 @@ private[hive] trait HiveInspectors {
     case soi: StandardStructObjectInspector =>
       val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
       (o: Any) => {
-        val struct = soi.create()
-        (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach {
-          (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data))
+        if (o != null) {
+          val struct = soi.create()
+          (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach {
+            (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data))
+          }
+          struct
+        } else {
+          null
         }
-        struct
       }
 
     case loi: ListObjectInspector =>
       val wrapper = wrapperFor(loi.getListElementObjectInspector)
-      (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper))
+      (o: Any) => if (o != null) seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) else null
 
     case moi: MapObjectInspector =>
       // The Predef.Map is scala.collection.immutable.Map.
@@ -364,9 +368,15 @@ private[hive] trait HiveInspectors {
       val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector)
       val valueWrapper = wrapperFor(moi.getMapValueObjectInspector)
-      (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) =>
-        keyWrapper(key) -> valueWrapper(value)
-      })
+      (o: Any) => {
+        if (o != null) {
+          mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) =>
+            keyWrapper(key) -> valueWrapper(value)
+          })
+        } else {
+          null
+        }
+      }
 
     case _ =>
       identity[Any]
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index c1c3683f84ab279e5c687506c722b22e373bcd2d..d41eb9e870bf087955b276b29b7a2c5627ed77e3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -21,7 +21,7 @@
 import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.types._
 
 case class Nested1(f1: Nested2)
 case class Nested2(f2: Nested3)
@@ -214,4 +214,39 @@ class SQLQuerySuite extends QueryTest {
       Seq.empty[Row])
     }
   }
+
+  test("SPARK-5284 Insert into Hive throws NPE when a inner complex type field has a null value") {
+    val schema = StructType(
+      StructField("s",
+        StructType(
+          StructField("innerStruct", StructType(StructField("s1", StringType, true) :: Nil)) ::
+            StructField("innerArray", ArrayType(IntegerType), true) ::
+            StructField("innerMap", MapType(StringType, IntegerType)) :: Nil), true) :: Nil)
+    val row = Row(Row(null, null, null))
+
+    val rowRdd = sparkContext.parallelize(row :: Nil)
+
+    applySchema(rowRdd, schema).registerTempTable("testTable")
+
+    sql(
+      """CREATE TABLE nullValuesInInnerComplexTypes
+        |  (s struct<innerStruct: struct<s1:string>,
+        |            innerArray:array<int>,
+        |            innerMap: map<string, int>>)
+      """.stripMargin).collect
+
+    sql(
+      """
+        |INSERT OVERWRITE TABLE nullValuesInInnerComplexTypes
+        |SELECT * FROM testTable
+      """.stripMargin)
+
+    checkAnswer(
+      sql("SELECT * FROM nullValuesInInnerComplexTypes"),
+      Seq(Seq(Seq(null, null, null)))
+    )
+
+    sql("DROP TABLE nullValuesInInnerComplexTypes")
+    dropTempTable("testTable")
+  }
 }
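
Note on the fix: wrapperFor builds one converter closure per Hive type, and before this
patch the struct, list, and map closures cast their input unconditionally
(o.asInstanceOf[...]), so a non-null outer row carrying a null nested struct, array, or
map raised a NullPointerException during INSERT. The sketch below is a minimal,
self-contained illustration of the same guard pattern, written without Hive's
ObjectInspector classes; SimpleType, IntT, ListT, and toJava are hypothetical names for
illustration only, not Spark or Hive APIs.

// Minimal sketch of the null-guard pattern the patch applies in wrapperFor.
// All names here are hypothetical; only the guard's shape mirrors the patch.
import scala.collection.JavaConverters._

object NullGuardSketch {
  sealed trait SimpleType
  case object IntT extends SimpleType
  case class ListT(elem: SimpleType) extends SimpleType

  // Build the converter once per type, as wrapperFor does. The returned closure
  // must tolerate null, because a non-null outer row can still hold null fields.
  def toJava(t: SimpleType): Any => Any = t match {
    case IntT => identity[Any]
    case ListT(elem) =>
      val wrap = toJava(elem)
      // Pre-patch shape: o.asInstanceOf[Seq[_]].map(...) NPEs when o is null.
      // Guarding returns null, so a NULL cell is written instead of crashing.
      (o: Any) => if (o != null) o.asInstanceOf[Seq[_]].map(wrap).asJava else null
  }

  def main(args: Array[String]): Unit = {
    val conv = toJava(ListT(IntT))
    println(conv(Seq(1, 2, 3))) // prints [1, 2, 3]
    println(conv(null))         // prints null rather than throwing an NPE
  }
}

The patch applies the same guard uniformly to the struct, list, and map branches of
wrapperFor; primitive types fall through to identity[Any], which already passes null
through unchanged.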