diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 41470ae6aae19d5bfb68c69bff2990fe43d551a0..a5e38e25b1ec59a8462f62e24a5b5ab86f5222c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
 import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.command.DDLUtils
+import org.apache.spark.sql.execution.datasources.csv._
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.jdbc._
 import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
@@ -368,14 +369,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         createParser)
     }
 
-    // Check a field requirement for corrupt records here to throw an exception in a driver side
-    schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
-      val f = schema(corruptFieldIndex)
-      if (f.dataType != StringType || !f.nullable) {
-        throw new AnalysisException(
-          "The field for corrupt records must be string type and nullable")
-      }
-    }
+    verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
 
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
       val parser = new JacksonParser(schema, parsedOptions)
@@ -398,6 +392,51 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     csv(Seq(path): _*)
   }
 
+  /**
+   * Loads a `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`.
+   *
+   * If the schema is not specified using the `schema` function and the `inferSchema` option is
+   * enabled, this function goes through the input once to determine the input schema.
+   *
+   * If the schema is not specified using the `schema` function and the `inferSchema` option is
+   * disabled, all columns are treated as string types and only the first line is read to
+   * determine the names and the number of fields.
+   *
+   * @param csvDataset input Dataset with one CSV row per record
+   * @since 2.2.0
+   */
+  def csv(csvDataset: Dataset[String]): DataFrame = {
+    val parsedOptions: CSVOptions = new CSVOptions(
+      extraOptions.toMap,
+      sparkSession.sessionState.conf.sessionLocalTimeZone)
+    val filteredLines: Dataset[String] =
+      CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions)
+    val maybeFirstLine: Option[String] = filteredLines.take(1).headOption
+
+    val schema = userSpecifiedSchema.getOrElse {
+      TextInputCSVDataSource.inferFromDataset(
+        sparkSession,
+        csvDataset,
+        maybeFirstLine,
+        parsedOptions)
+    }
+
+    verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+
+    val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+      filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
+    }.getOrElse(filteredLines.rdd)
+
+    val parsed = linesWithoutHeader.mapPartitions { iter =>
+      val parser = new UnivocityParser(schema, parsedOptions)
+      iter.flatMap(line => parser.parse(line))
+    }
+
+    Dataset.ofRows(
+      sparkSession,
+      LogicalRDD(schema.toAttributes, parsed)(sparkSession))
+  }
+
   /**
    * Loads a CSV file and returns the result as a `DataFrame`.
    *
@@ -604,6 +643,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
   }
 
+  /**
+   * A convenience function for schema validation in data sources that support
+   * `columnNameOfCorruptRecord` as an option.
+   */
+  private def verifyColumnNameOfCorruptRecord(
+      schema: StructType,
+      columnNameOfCorruptRecord: String): Unit = {
+    schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+      val f = schema(corruptFieldIndex)
+      if (f.dataType != StringType || !f.nullable) {
+        throw new AnalysisException(
+          "The field for corrupt records must be string type and nullable")
+      }
+    }
+  }
+
   ///////////////////////////////////////////////////////////////////////////////////////
   // Builder pattern config options
   ///////////////////////////////////////////////////////////////////////////////////////
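For reviewers, a quick usage sketch of the new overload. The session setup and the sample rows are illustrative, not part of the patch:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Any Dataset[String] works as input; here the rows are built inline.
val csvLines = Seq("name,age", "alice,29", "bob,31").toDS()

val df = spark.read
  .option("header", "true")       // first line supplies column names
  .option("inferSchema", "true")  // one extra pass over the input to infer types
  .csv(csvLines)                  // the overload added by this patch

df.printSchema()  // name: string, age: int
```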
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 47567032b01958de314463d36aa8bcd871a50d0d..35ff924f27ce5745711356a3665e363e7c2a0ba6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -17,12 +17,11 @@
 
 package org.apache.spark.sql.execution.datasources.csv
 
-import java.io.InputStream
 import java.nio.charset.{Charset, StandardCharsets}
 
-import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}
+import com.univocity.parsers.csv.CsvParser
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.FileStatus
 import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.Job
@@ -134,23 +133,33 @@ object TextInputCSVDataSource extends CSVDataSource {
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): Option[StructType] = {
     val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
-    CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption match {
-      case Some(firstLine) =>
-        val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
-        val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-        val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
-        val tokenRDD = csv.rdd.mapPartitions { iter =>
-          val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
-          val linesWithoutHeader =
-            CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
-          val parser = new CsvParser(parsedOptions.asParserSettings)
-          linesWithoutHeader.map(parser.parseLine)
-        }
-        Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
-      case None =>
-        // If the first line could not be read, just return the empty schema.
-        Some(StructType(Nil))
-    }
+    val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
+    Some(inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions))
+  }
+
+  /**
+   * Infers the schema from a `Dataset` that stores CSV string records.
+   */
+  def inferFromDataset(
+      sparkSession: SparkSession,
+      csv: Dataset[String],
+      maybeFirstLine: Option[String],
+      parsedOptions: CSVOptions): StructType = maybeFirstLine match {
+    case Some(firstLine) =>
+      val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
+      val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+      val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+      val tokenRDD = csv.rdd.mapPartitions { iter =>
+        val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
+        val linesWithoutHeader =
+          CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
+        val parser = new CsvParser(parsedOptions.asParserSettings)
+        linesWithoutHeader.map(parser.parseLine)
+      }
+      CSVInferSchema.infer(tokenRDD, header, parsedOptions)
+    case None =>
+      // If the first line could not be read, just return the empty schema.
+      StructType(Nil)
   }
 
   private def createBaseDataset(
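The refactoring above lifts the old `match` out of `infer` into the now-public `inferFromDataset`, so the file-based path and the new `Dataset[String]` path share one inference routine. The `None` branch is what makes empty input well-defined; a sketch of the observable behavior, assuming a `SparkSession` named `spark`:

```scala
import org.apache.spark.sql.types.StructType
import spark.implicits._

// No readable first line => inferFromDataset takes the None branch and
// returns StructType(Nil), i.e. a DataFrame with zero columns.
val empty = spark.read.csv(spark.emptyDataset[String])
assert(empty.schema == StructType(Nil))
```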
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 50503385ad6d151f8cc7299946bde1dff50e7552..0b1e5dac2da66c22e21c001cc22cca47ebfe7a78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -26,7 +26,7 @@ import org.apache.commons.lang3.time.FastDateFormat
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}
 
-private[csv] class CSVOptions(
+class CSVOptions(
     @transient private val parameters: CaseInsensitiveMap[String],
     defaultTimeZoneId: String,
     defaultColumnNameOfCorruptRecord: String)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 3b3b87e4354d6f4f06411590f16958bd05163c1d..e42ea3fa391f5c39606a18d70460a953575e14b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-private[csv] class UnivocityParser(
+class UnivocityParser(
     schema: StructType,
     requiredSchema: StructType,
     private val options: CSVOptions) extends Logging {
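Both classes lose their `private[csv]` modifier only because `DataFrameReader` lives in `org.apache.spark.sql` and now instantiates them directly. A minimal sketch of what that makes possible, with illustrative values; the two-argument constructors are the ones the new `csv(Dataset[String])` method itself calls:

```scala
import org.apache.spark.sql.execution.datasources.csv.{CSVOptions, UnivocityParser}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Both classes are visible outside the csv package after this patch.
val options = new CSVOptions(Map("header" -> "false"), "UTC")
val schema = StructType(
  StructField("name", StringType) :: StructField("age", IntegerType) :: Nil)

// One parser per partition; parse yields zero or more rows per line,
// which is why the reader drives it through flatMap.
val parser = new UnivocityParser(schema, options)
val rows = Iterator("alice,29", "bob,31").flatMap(line => parser.parse(line))
```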
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index eaedede349134281650389cf9414d7ef45985ae4..4435e4df38ef66c0b6a9a30f945680cde17f80b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -129,6 +129,22 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     verifyCars(cars, withHeader = true, checkTypes = true)
   }
 
+  test("simple csv test with string dataset") {
+    val csvDataset = spark.read.text(testFile(carsFile)).as[String]
+    val cars = spark.read
+      .option("header", "true")
+      .option("inferSchema", "true")
+      .csv(csvDataset)
+
+    verifyCars(cars, withHeader = true, checkTypes = true)
+
+    val carsWithoutHeader = spark.read
+      .option("header", "false")
+      .csv(csvDataset)
+
+    verifyCars(carsWithoutHeader, withHeader = false, checkTypes = false)
+  }
+
   test("test inferring booleans") {
     val result = spark.read
       .format("csv")
@@ -1088,4 +1104,15 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       checkAnswer(df, spark.emptyDataFrame)
     }
   }
+
+  test("Empty string dataset produces empty dataframe and keeps user-defined schema") {
+    val df1 = spark.read.csv(spark.emptyDataset[String])
+    assert(df1.schema === spark.emptyDataFrame.schema)
+    checkAnswer(df1, spark.emptyDataFrame)
+
+    val schema = StructType(StructField("a", StringType) :: Nil)
+    val df2 = spark.read.schema(schema).csv(spark.emptyDataset[String])
+    assert(df2.schema === schema)
+  }
+
 }
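One behavior the suite does not exercise directly: `verifyColumnNameOfCorruptRecord` now guards the new CSV path as well as the JSON one. A test-style sketch of the failure it raises, assuming the `columnNameOfCorruptRecord` option is honored under that name on this path, as the call site in `csv(Dataset[String])` suggests; the schema and sample data are illustrative:

```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.types.{IntegerType, StructType}
import spark.implicits._

// A corrupt-record column that is not a nullable string must be rejected
// on the driver, before any executor-side parsing starts.
val badSchema = new StructType()
  .add("value", IntegerType)
  .add("_malformed", IntegerType)  // wrong type on purpose

intercept[AnalysisException] {
  spark.read
    .schema(badSchema)
    .option("columnNameOfCorruptRecord", "_malformed")
    .csv(Seq("1", "oops").toDS())
}
// => "The field for corrupt records must be string type and nullable"
```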