diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 12e19f955caa29dd0891a79673398f0ef05d3b5f..1bf57882ce0231661cbc74dab05f0684717c704d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -56,7 +56,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
     val paths = files.filterNot(_.getPath.getName startsWith "_").map(_.getPath.toString)
     val rdd = baseRdd(sparkSession, csvOptions, paths)
     val firstLine = findFirstLine(csvOptions, rdd)
-    val firstRow = new LineCsvReader(csvOptions).parseLine(firstLine)
+    val firstRow = new CsvReader(csvOptions).parseLine(firstLine)
 
     val header = if (csvOptions.headerFlag) {
       firstRow.zipWithIndex.map { case (value, index) =>
@@ -103,6 +103,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
     val csvOptions = new CSVOptions(options)
+    val commentPrefix = csvOptions.comment.toString
     val headers = requiredSchema.fields.map(_.name)
 
     val broadcastedHadoopConf =
@@ -118,7 +119,12 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
 
       CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)
 
-      val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
+      val csvParser = new CsvReader(csvOptions)
+      val tokenizedIterator = lineIterator.filter { line =>
+        line.trim.nonEmpty && !line.startsWith(commentPrefix)
+      }.map { line =>
+        csvParser.parseLine(line)
+      }
       val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
       var numMalformedRecords = 0
       tokenizedIterator.flatMap { recordTokens =>
@@ -146,7 +152,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
     val rdd = baseRdd(sparkSession, options, inputPaths)
     // Make sure firstLine is materialized before sending to executors
    val firstLine = if (options.headerFlag) findFirstLine(options, rdd) else null
-    CSVRelation.univocityTokenizer(rdd, header, firstLine, options)
+    CSVRelation.univocityTokenizer(rdd, firstLine, options)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
index 21032625800b1a0e68a0e1ee1360e614b5ac168e..bf62732dd40482d5e2dc45107c0ff846e1425ffb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
@@ -27,11 +27,10 @@ import org.apache.spark.internal.Logging
  * Read and parse CSV-like input
  *
  * @param params Parameters object
- * @param headers headers for the columns
  */
-private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) {
+private[sql] class CsvReader(params: CSVOptions) {
 
-  protected lazy val parser: CsvParser = {
+  private val parser: CsvParser = {
     val settings = new CsvParserSettings()
     val format = settings.getFormat
     format.setDelimiter(params.delimiter)
@@ -47,10 +46,17 @@ private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String])
     settings.setNullValue(params.nullValue)
     settings.setMaxCharsPerColumn(params.maxCharsPerColumn)
     settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER)
-    if (headers != null) settings.setHeaders(headers: _*)
 
     new CsvParser(settings)
   }
+
+  /**
+   * parse a line
+   *
+   * @param line a String with no newline at the end
+   * @return array of strings where each string is a field in the CSV record
+   */
+  def parseLine(line: String): Array[String] = parser.parseLine(line)
 }
 
 /**
@@ -97,157 +103,3 @@ private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten
     writer.close()
   }
 }
-
-/**
- * Parser for parsing a line at a time. Not efficient for bulk data.
- *
- * @param params Parameters object
- */
-private[sql] class LineCsvReader(params: CSVOptions)
-  extends CsvReader(params, null) {
-  /**
-   * parse a line
-   *
-   * @param line a String with no newline at the end
-   * @return array of strings where each string is a field in the CSV record
-   */
-  def parseLine(line: String): Array[String] = {
-    parser.beginParsing(new StringReader(line))
-    val parsed = parser.parseNext()
-    parser.stopParsing()
-    parsed
-  }
-}
-
-/**
- * Parser for parsing lines in bulk. Use this when efficiency is desired.
- *
- * @param iter iterator over lines in the file
- * @param params Parameters object
- * @param headers headers for the columns
- */
-private[sql] class BulkCsvReader(
-    iter: Iterator[String],
-    params: CSVOptions,
-    headers: Seq[String])
-  extends CsvReader(params, headers) with Iterator[Array[String]] {
-
-  private val reader = new StringIteratorReader(iter)
-  parser.beginParsing(reader)
-  private var nextRecord = parser.parseNext()
-
-  /**
-   * get the next parsed line.
-   * @return array of strings where each string is a field in the CSV record
-   */
-  override def next(): Array[String] = {
-    val curRecord = nextRecord
-    if(curRecord != null) {
-      nextRecord = parser.parseNext()
-    } else {
-      throw new NoSuchElementException("next record is null")
-    }
-    curRecord
-  }
-
-  override def hasNext: Boolean = nextRecord != null
-
-}
-
-/**
- * A Reader that "reads" from a sequence of lines. Spark's textFile method removes newlines at
- * end of each line Univocity parser requires a Reader that provides access to the data to be
- * parsed and needs the newlines to be present
- * @param iter iterator over RDD[String]
- */
-private class StringIteratorReader(val iter: Iterator[String]) extends java.io.Reader {
-
-  private var next: Long = 0
-  private var length: Long = 0  // length of input so far
-  private var start: Long = 0
-  private var str: String = null // current string from iter
-
-  /**
-   * fetch next string from iter, if done with current one
-   * pretend there is a new line at the end of every string we get from from iter
-   */
-  private def refill(): Unit = {
-    if (length == next) {
-      if (iter.hasNext) {
-        str = iter.next()
-        start = length
-        length += (str.length + 1) // allowance for newline removed by SparkContext.textFile()
-      } else {
-        str = null
-      }
-    }
-  }
-
-  /**
-   * read the next character, if at end of string pretend there is a new line
-   */
-  override def read(): Int = {
-    refill()
-    if (next >= length) {
-      -1
-    } else {
-      val cur = next - start
-      next += 1
-      if (cur == str.length) '\n' else str.charAt(cur.toInt)
-    }
-  }
-
-  /**
-   * read from str into cbuf
-   */
-  override def read(cbuf: Array[Char], off: Int, len: Int): Int = {
-    refill()
-    var n = 0
-    if ((off < 0) || (off > cbuf.length) || (len < 0) ||
-      ((off + len) > cbuf.length) || ((off + len) < 0)) {
-      throw new IndexOutOfBoundsException()
-    } else if (len == 0) {
-      n = 0
-    } else {
-      if (next >= length) { // end of input
-        n = -1
-      } else {
-        n = Math.min(length - next, len).toInt // lesser of amount of input available or buf size
-        if (n == length - next) {
-          str.getChars((next - start).toInt, (next - start + n - 1).toInt, cbuf, off)
-          cbuf(off + n - 1) = '\n'
-        } else {
-          str.getChars((next - start).toInt, (next - start + n).toInt, cbuf, off)
-        }
-        next += n
-        if (n < len) {
-          val m = read(cbuf, off + n, len - n) // have more space, fetch more input from iter
-          if(m != -1) n += m
-        }
-      }
-    }
-
-    n
-  }
-
-  override def skip(ns: Long): Long = {
-    throw new IllegalArgumentException("Skip not implemented")
-  }
-
-  override def ready: Boolean = {
-    refill()
-    true
-  }
-
-  override def markSupported: Boolean = false
-
-  override def mark(readAheadLimit: Int): Unit = {
-    throw new IllegalArgumentException("Mark not implemented")
-  }
-
-  override def reset(): Unit = {
-    throw new IllegalArgumentException("Mark and hence reset not implemented")
-  }
-
-  override def close(): Unit = { }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index e8c0134d38803e3445cd8ac56b8eb7b5449023d0..c6ba424d8687595c9b3675c4ab44523e4a04a7c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -38,15 +38,24 @@ object CSVRelation extends Logging {
 
   def univocityTokenizer(
       file: RDD[String],
-      header: Seq[String],
       firstLine: String,
       params: CSVOptions): RDD[Array[String]] = {
     // If header is set, make sure firstLine is materialized before sending to executors.
+    val commentPrefix = params.comment.toString
     file.mapPartitions { iter =>
-      new BulkCsvReader(
-        if (params.headerFlag) iter.filterNot(_ == firstLine) else iter,
-        params,
-        headers = header)
+      val parser = new CsvReader(params)
+      val filteredIter = iter.filter { line =>
+        line.trim.nonEmpty && !line.startsWith(commentPrefix)
+      }
+      if (params.headerFlag) {
+        filteredIter.filterNot(_ == firstLine).map { item =>
+          parser.parseLine(item)
+        }
+      } else {
+        filteredIter.map { item =>
+          parser.parseLine(item)
+        }
+      }
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala
deleted file mode 100644
index aaeecef5f37fcc521eb8340764273ac77d541d46..0000000000000000000000000000000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala
+++ /dev/null
@@ -1,125 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.csv
-
-import org.apache.spark.SparkFunSuite
-
-/**
- * test cases for StringIteratorReader
- */
-class CSVParserSuite extends SparkFunSuite {
-
-  private def readAll(iter: Iterator[String]) = {
-    val reader = new StringIteratorReader(iter)
-    var c: Int = -1
-    val read = new scala.collection.mutable.StringBuilder()
-    do {
-      c = reader.read()
-      read.append(c.toChar)
-    } while (c != -1)
-
-    read.dropRight(1).toString
-  }
-
-  private def readBufAll(iter: Iterator[String], bufSize: Int) = {
-    val reader = new StringIteratorReader(iter)
-    val cbuf = new Array[Char](bufSize)
-    val read = new scala.collection.mutable.StringBuilder()
-
-    var done = false
-    do { // read all input one cbuf at a time
-      var numRead = 0
-      var n = 0
-      do { // try to fill cbuf
-        var off = 0
-        var len = cbuf.length
-        n = reader.read(cbuf, off, len)
-
-        if (n != -1) {
-          off += n
-          len -= n
-        }
-
-        assert(len >= 0 && len <= cbuf.length)
-        assert(off >= 0 && off <= cbuf.length)
-        read.appendAll(cbuf.take(n))
-      } while (n > 0)
-      if(n != -1) {
-        numRead += n
-      } else {
-        done = true
-      }
-    } while (!done)
-
-    read.toString
-  }
-
-  test("Hygiene") {
-    val reader = new StringIteratorReader(List("").toIterator)
-    assert(reader.ready === true)
-    assert(reader.markSupported === false)
-    intercept[IllegalArgumentException] { reader.skip(1) }
-    intercept[IllegalArgumentException] { reader.mark(1) }
-    intercept[IllegalArgumentException] { reader.reset() }
-  }
-
-  test("Regular case") {
-    val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"")
-    val read = readAll(input.toIterator)
-    assert(read === input.mkString("\n") ++ "\n")
-  }
-
-  test("Empty iter") {
-    val input = List[String]()
-    val read = readAll(input.toIterator)
-    assert(read === "")
-  }
-
-  test("Embedded new line") {
-    val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"")
-    val read = readAll(input.toIterator)
-    assert(read === input.mkString("\n") ++ "\n")
-  }
-
-  test("Buffer Regular case") {
-    val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"")
-    val output = input.mkString("\n") ++ "\n"
-    for(i <- 1 to output.length + 5) {
-      val read = readBufAll(input.toIterator, i)
-      assert(read === output)
-    }
-  }
-
-  test("Buffer Empty iter") {
-    val input = List[String]()
-    val output = ""
-    for(i <- 1 to output.length + 5) {
-      val read = readBufAll(input.toIterator, 1)
-      assert(read === "")
-    }
-  }
-
-  test("Buffer Embedded new line") {
-    val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"")
-    val output = input.mkString("\n") ++ "\n"
-    for(i <- 1 to output.length + 5) {
-      val read = readBufAll(input.toIterator, 1)
-      assert(read === output)
-    }
-  }
-}
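
Not part of the diff: a minimal, self-contained sketch of the per-line tokenizing flow this patch switches to, for reviewers who want to try it outside Spark. It assumes univocity-parsers is on the classpath; the object name `PerLineParseSketch`, the sample `lines`, and the hard-coded `','` delimiter and `'#'` comment character are illustrative stand-ins for the values that `CSVOptions` normally supplies.

```scala
import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}

object PerLineParseSketch {
  def main(args: Array[String]): Unit = {
    // Configure univocity once (as CsvReader's constructor does above),
    // then reuse the same parser for every line instead of streaming
    // the whole partition through a java.io.Reader.
    val settings = new CsvParserSettings()
    settings.getFormat.setDelimiter(',')   // illustrative; Spark takes this from CSVOptions
    settings.getFormat.setComment('#')     // illustrative; Spark takes this from CSVOptions
    val parser = new CsvParser(settings)

    val commentPrefix = "#"
    val lines = Iterator("a,b,c", "", "# skipped comment", "1,2,3")

    // Mirrors the new iterator pipeline: drop blank and comment-prefixed
    // lines up front, then tokenize each remaining line individually.
    val tokenized = lines
      .filter(line => line.trim.nonEmpty && !line.startsWith(commentPrefix))
      .map(line => parser.parseLine(line))

    tokenized.foreach(tokens => println(tokens.mkString("|")))
  }
}
```

The up-front filter mirrors what the diff adds in both CSVFileFormat and CSVRelation: since each record is now handed to `parseLine` in isolation rather than through `StringIteratorReader`, blank and comment-prefixed lines are skipped before tokenizing.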