diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 73e6abc6dad3790af2ff1229d665c9dde57fc5ed..47567032b01958de314463d36aa8bcd871a50d0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -133,20 +133,24 @@ object TextInputCSVDataSource extends CSVDataSource {
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): Option[StructType] = {
-    val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions)
-    val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first()
-    val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
-    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-    val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
-    val tokenRDD = csv.rdd.mapPartitions { iter =>
-      val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
-      val linesWithoutHeader =
-        CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
-      val parser = new CsvParser(parsedOptions.asParserSettings)
-      linesWithoutHeader.map(parser.parseLine)
+    val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
+    CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption match {
+      case Some(firstLine) =>
+        val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
+        val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+        val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+        val tokenRDD = csv.rdd.mapPartitions { iter =>
+          val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
+          val linesWithoutHeader =
+            CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
+          val parser = new CsvParser(parsedOptions.asParserSettings)
+          linesWithoutHeader.map(parser.parseLine)
+        }
+        Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
+      case None =>
+        // If the first line could not be read, just return the empty schema.
+        Some(StructType(Nil))
     }
-
-    Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
   }
 
   private def createBaseDataset(
@@ -190,28 +194,28 @@ object WholeFileCSVDataSource extends CSVDataSource {
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): Option[StructType] = {
-    val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions)
-    val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines =>
+    val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions)
+    csv.flatMap { lines =>
       UnivocityParser.tokenizeStream(
         CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()),
-        false,
+        shouldDropHeader = false,
         new CsvParser(parsedOptions.asParserSettings))
-    }.take(1).headOption
-
-    if (maybeFirstRow.isDefined) {
-      val firstRow = maybeFirstRow.get
-      val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-      val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
-      val tokenRDD = csv.flatMap { lines =>
-        UnivocityParser.tokenizeStream(
-          CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()),
-          parsedOptions.headerFlag,
-          new CsvParser(parsedOptions.asParserSettings))
-      }
-      Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
-    } else {
-      // If the first row could not be read, just return the empty schema.
-      Some(StructType(Nil))
+    }.take(1).headOption match {
+      case Some(firstRow) =>
+        val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+        val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+        val tokenRDD = csv.flatMap { lines =>
+          UnivocityParser.tokenizeStream(
+            CodecStreams.createInputStreamWithCloseResource(
+              lines.getConfiguration,
+              lines.getPath()),
+            parsedOptions.headerFlag,
+            new CsvParser(parsedOptions.asParserSettings))
+        }
+        Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
+      case None =>
+        // If the first row could not be read, just return the empty schema.
+        Some(StructType(Nil))
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 56071803f685fa1a0f76ba0748c9e008c28814b2..eaedede349134281650389cf9414d7ef45985ae4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1077,14 +1077,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     }
   }
 
-  test("Empty file produces empty dataframe with empty schema - wholeFile option") {
-    withTempPath { path =>
-      path.createNewFile()
-
+  test("Empty file produces empty dataframe with empty schema") {
+    Seq(false, true).foreach { wholeFile =>
       val df = spark.read.format("csv")
         .option("header", true)
-        .option("wholeFile", true)
-        .load(path.getAbsolutePath)
+        .option("wholeFile", wholeFile)
+        .load(testFile(emptyFile))
 
       assert(df.schema === spark.emptyDataFrame.schema)
       checkAnswer(df, spark.emptyDataFrame)
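
Illustrative sketch (not part of the patch) of the behavior this change enables: the non-wholeFile path previously called .first() on the filtered dataset, which fails on an empty file, whereas with the change schema inference returns an empty schema for both code paths. The object name and the path /tmp/empty.csv are hypothetical; it assumes a local SparkSession and an existing empty file at that path.

import org.apache.spark.sql.SparkSession

object EmptyCsvRead {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("EmptyCsvRead")
      .getOrCreate()

    // Reading an empty CSV file yields an empty DataFrame with an empty schema
    // for both wholeFile = false and wholeFile = true, mirroring the updated test.
    Seq("false", "true").foreach { wholeFile =>
      val df = spark.read
        .option("header", "true")
        .option("wholeFile", wholeFile)
        .csv("/tmp/empty.csv") // hypothetical path to an empty file

      assert(df.schema.isEmpty) // StructType(Nil)
      assert(df.count() == 0L)
    }

    spark.stop()
  }
}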