diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 72fd184d580cce49bdc366c46905d649053fe1a5..89506ca02f27314eb80ec038225d215c87b564a8 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -392,6 +392,10 @@ class DataFrameReader(ReaderUtils):
         :param maxCharsPerColumn: defines the maximum number of characters allowed for any given
                                   value being read. If None is set, it uses the default value,
                                   ``1000000``.
+        :param maxMalformedLogPerPartition: sets the maximum number of malformed rows Spark will
+                                            log for each partition. Malformed records beyond this
+                                            number will be ignored. If None is set, it
+                                            uses the default value, ``10``.
         :param mode: allows a mode for dealing with corrupt records during parsing. If None is
                      set, it uses the default value, ``PERMISSIVE``.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 841503b260c35c8742aed6d2896a4035a6149c6b..35ba9c50790e479bd47dcdaa356701868130d381 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -382,6 +382,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * a record can have.</li>
    * <li>`maxCharsPerColumn` (default `1000000`): defines the maximum number of characters allowed
    * for any given value being read.</li>
+   * <li>`maxMalformedLogPerPartition` (default `10`): sets the maximum number of malformed rows
+   * Spark will log for each partition. Malformed records beyond this number will be ignored.</li>
    * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
    * during parsing.</li>
    * <ul>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index be52de8e4014042eebb5a35cad73bdcd64d78072..12e19f955caa29dd0891a79673398f0ef05d3b5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -120,7 +120,14 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
         val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
         val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
-        tokenizedIterator.flatMap(parser(_).toSeq)
+        var numMalformedRecords = 0
+        tokenizedIterator.flatMap { recordTokens =>
+          val row = parser(recordTokens, numMalformedRecords)
+          if (row.isEmpty) {
+            numMalformedRecords += 1
+          }
+          row
+        }
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 9f4ce8358b045eb51e91e1ef51d71e6148a195f0..581eda7e09a3e181a046bd88da83cb9dbf5c49da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -113,6 +113,8 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, String])
   val escapeQuotes = getBool("escapeQuotes", true)
 
+  val maxMalformedLogPerPartition = getInt("maxMalformedLogPerPartition", 10)
+
   val inputBufferSize = 128
 
   val isCommentSet = this.comment != '\u0000'
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index d72c8b9ac2e7cccb40d1e20cfb2611e904c70eee..083ac3350ef023d3cc53984eef2dba64fc208e06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -50,10 +50,19 @@ object CSVRelation extends Logging {
     }
   }
 
+  /**
+   * Returns a function that parses a single CSV record (in the form of an array of strings in
+   * which each element represents a column) and turns it into either one resulting row or no
+   * row (if the record is malformed).
+   *
+   * The second argument in the returned function represents the total number of malformed rows
+   * observed so far.
+   */
+  // This is pretty convoluted and we should probably rewrite the entire CSV parsing soon.
   def csvParser(
       schema: StructType,
       requiredColumns: Array[String],
-      params: CSVOptions): Array[String] => Option[InternalRow] = {
+      params: CSVOptions): (Array[String], Int) => Option[InternalRow] = {
     val schemaFields = schema.fields
     val requiredFields = StructType(requiredColumns.map(schema(_))).fields
     val safeRequiredFields = if (params.dropMalformed) {
@@ -72,9 +81,16 @@ object CSVRelation extends Logging {
     val requiredSize = requiredFields.length
     val row = new GenericMutableRow(requiredSize)
 
-    (tokens: Array[String]) => {
+    (tokens: Array[String], numMalformedRows) => {
       if (params.dropMalformed && schemaFields.length != tokens.length) {
-        logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+        if (numMalformedRows < params.maxMalformedLogPerPartition) {
+          logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+        }
+        if (numMalformedRows == params.maxMalformedLogPerPartition - 1) {
+          logWarning(
+            s"More than ${params.maxMalformedLogPerPartition} malformed records have been " +
+              "found on this partition. Malformed records from now on will not be logged.")
+        }
         None
       } else if (params.failFast && schemaFields.length != tokens.length) {
         throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
@@ -109,23 +125,21 @@ object CSVRelation extends Logging {
           Some(row)
         } catch {
           case NonFatal(e) if params.dropMalformed =>
-            logWarning("Parse exception. " +
-              s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+            if (numMalformedRows < params.maxMalformedLogPerPartition) {
+              logWarning("Parse exception. " +
+                s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+            }
+            if (numMalformedRows == params.maxMalformedLogPerPartition - 1) {
+              logWarning(
+                s"More than ${params.maxMalformedLogPerPartition} malformed records have been " +
+                  "found on this partition. Malformed records from now on will not be logged.")
+            }
             None
         }
       }
     }
   }
 
-  def parseCsv(
-      tokenizedRDD: RDD[Array[String]],
-      schema: StructType,
-      requiredColumns: Array[String],
-      options: CSVOptions): RDD[InternalRow] = {
-    val parser = csvParser(schema, requiredColumns, options)
-    tokenizedRDD.flatMap(parser(_).toSeq)
-  }
-
   // Skips the header line of each file if the `header` option is set to true.
   def dropHeaderLine(
       file: PartitionedFile, lines: Iterator[String], csvOptions: CSVOptions): Unit = {
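For context, a minimal usage sketch (not part of the diff) of how the new option could be set through the DataFrameReader API in Scala, e.g. from spark-shell; the threshold value, input path, and the `spark` session variable are placeholders:

    spark.read
      .option("header", "true")
      .option("mode", "DROPMALFORMED")             // drop rows whose token count does not match the schema
      .option("maxMalformedLogPerPartition", "5")  // log at most 5 malformed-row warnings per partition
      .csv("/path/to/data.csv")

With this setting, each partition logs a warning for the first 5 malformed rows it drops, emits one final "will not be logged" warning, and then silently discards any further malformed rows.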