diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index 995095969d7af9251b05c2c71dee7e007ccd3321..9b80c0fc87c9384506f868fee3c3add22e80c886 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -58,7 +58,10 @@ class JacksonParser(
   private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length))
 
   private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord)
-  corruptFieldIndex.foreach(idx => require(schema(idx).dataType == StringType))
+  corruptFieldIndex.foreach { corrFieldIndex =>
+    require(schema(corrFieldIndex).dataType == StringType)
+    require(schema(corrFieldIndex).nullable)
+  }
 
   @transient
   private[this] var isWarningPrinted: Boolean = false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 4c1341ed5da608b06c313866b76405483fb58473..2be22761e8dbc9778a6f2d7f2366235d9293b199 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.jdbc._
 import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -365,6 +365,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         createParser)
     }
 
+    // Check the field requirement for corrupt records here, so the exception is thrown
+    // on the driver side
+    schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+      val f = schema(corruptFieldIndex)
+      if (f.dataType != StringType || !f.nullable) {
+        throw new AnalysisException(
+          "The field for corrupt records must be string type and nullable")
+      }
+    }
+
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
       val parser = new JacksonParser(schema, parsedOptions)
       iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 2cbf4ea7beaca1e31126f47cd79e2848bc2daade..902fee5a7e3f73e6980c0bb56b60087f4ec32fe7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -22,13 +22,13 @@ import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions}
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
 
 class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
@@ -102,6 +102,15 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
+    // Check the field requirement for corrupt records here, so the exception is thrown
+    // on the driver side
+    dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+      val f = dataSchema(corruptFieldIndex)
+      if (f.dataType != StringType || !f.nullable) {
+        throw new AnalysisException(
+          "The field for corrupt records must be string type and nullable")
+      }
+    }
+
     (file: PartitionedFile) => {
       val parser = new JacksonParser(requiredSchema, parsedOptions)
       JsonDataSource(parsedOptions).readFile(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 05aa2ab2ce2d0b86194849e254fa5e9878b65c6b..0e01be2410409388c2f48fc001ed6298feb3ae03 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -1944,4 +1944,35 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode"))
     }
   }
+
+  test("Throw an exception if a `columnNameOfCorruptRecord` field violates requirements") {
+    val columnNameOfCorruptRecord = "_unparsed"
+    val schema = StructType(
+      StructField(columnNameOfCorruptRecord, IntegerType, true) ::
+      StructField("a", StringType, true) ::
+      StructField("b", StringType, true) ::
+      StructField("c", StringType, true) :: Nil)
+    val errMsg = intercept[AnalysisException] {
+      spark.read
+        .option("mode", "PERMISSIVE")
+        .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+        .schema(schema)
+        .json(corruptRecords)
+    }.getMessage
+    assert(errMsg.startsWith("The field for corrupt records must be string type and nullable"))
+
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      corruptRecords.toDF("value").write.text(path)
+      val errMsg = intercept[AnalysisException] {
+        spark.read
+          .option("mode", "PERMISSIVE")
+          .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+          .schema(schema)
+          .json(path)
+          .collect
+      }.getMessage
+      assert(errMsg.startsWith("The field for corrupt records must be string type and nullable"))
+    }
+  }
 }
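
For reference, a minimal sketch of how the new driver-side check surfaces to users. This is not part of the patch; it mirrors the JsonSuite test above, and it assumes a running SparkSession named `spark`; the input path is hypothetical.

    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    // The corrupt-record column is declared as IntegerType, violating the new
    // requirement that it be a nullable StringType.
    val schema = StructType(
      StructField("_unparsed", IntegerType, nullable = true) ::
      StructField("a", StringType, nullable = true) :: Nil)

    // Fails on the driver with:
    //   org.apache.spark.sql.AnalysisException:
    //   The field for corrupt records must be string type and nullable
    spark.read
      .option("mode", "PERMISSIVE")
      .option("columnNameOfCorruptRecord", "_unparsed")
      .schema(schema)
      .json("/path/to/corrupt.json")  // hypothetical path
      .collect()

For the file-based reader the exception is raised when the read is actually planned (hence the `.collect()`), while the `Dataset[String]` overload of `DataFrameReader.json` checks eagerly; in both cases the failure happens on the driver rather than inside executor tasks.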