Commit 45512902 authored by hyukjinkwon, committed by Wenchen Fan

[SPARK-15463][SQL] Add an API to load DataFrame from Dataset[String] storing CSV

## What changes were proposed in this pull request?

This PR proposes to add an API that loads a `DataFrame` from a `Dataset[String]` storing CSV.

It allows pre-processing the data before parsing it as CSV, which enables workarounds for many narrow cases, for example, as below:

- Case 1 - pre-processing

  ```scala
  val df = spark.read.text("...")
  // Pre-process `df` here, then parse the resulting lines as CSV.
  spark.read.csv(df.as[String])
  ```

- Case 2 - use other input formats

  ```scala
  val rdd = spark.sparkContext.newAPIHadoopFile("/file.csv.lzo",
    classOf[com.hadoop.mapreduce.LzoTextInputFormat],
    classOf[org.apache.hadoop.io.LongWritable],
    classOf[org.apache.hadoop.io.Text])
  val stringRdd = rdd.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength))

  spark.read.csv(stringRdd.toDS)
  ```
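
- Case 3 - clean the lines before parsing (a rough additional sketch; the path, the column names, and the naive comma-count filter are made up purely for illustration)

  ```scala
  import org.apache.spark.sql.types._
  import spark.implicits._

  // Keep only lines that look like well-formed three-column records,
  // then parse the survivors as CSV with an explicit schema.
  val rawLines = spark.read.text("/path/to/file.csv").as[String]
  val cleaned = rawLines.filter(line => line.count(_ == ',') == 2)

  val schema = new StructType()
    .add("id", IntegerType)
    .add("name", StringType)
    .add("price", DoubleType)

  val df = spark.read.schema(schema).csv(cleaned)
  ```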

## How was this patch tested?

Added tests in `CSVSuite` and verified that the change builds with Scala 2.10:

```
./dev/change-scala-version.sh 2.10
./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package
```

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #16854 from HyukjinKwon/SPARK-15463.
parent 6570cfd7
@@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.csv._
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
@@ -368,14 +369,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
createParser)
}
// Check a field requirement for corrupt records here to throw an exception in a driver side
schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
val f = schema(corruptFieldIndex)
if (f.dataType != StringType || !f.nullable) {
throw new AnalysisException(
"The field for corrupt records must be string type and nullable")
}
}
verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
val parsed = jsonDataset.rdd.mapPartitions { iter =>
val parser = new JacksonParser(schema, parsedOptions)
@@ -398,6 +392,51 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
csv(Seq(path): _*)
}
/**
* Loads a `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`.
*
* If the schema is not specified using `schema` function and `inferSchema` option is enabled,
* this function goes through the input once to determine the input schema.
*
* If the schema is not specified using `schema` function and `inferSchema` option is disabled,
* it determines the columns as string types and it reads only the first line to determine the
* names and the number of fields.
*
* @param csvDataset input Dataset with one CSV row per record
* @since 2.2.0
*/
def csv(csvDataset: Dataset[String]): DataFrame = {
val parsedOptions: CSVOptions = new CSVOptions(
extraOptions.toMap,
sparkSession.sessionState.conf.sessionLocalTimeZone)
val filteredLines: Dataset[String] =
CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions)
val maybeFirstLine: Option[String] = filteredLines.take(1).headOption
val schema = userSpecifiedSchema.getOrElse {
TextInputCSVDataSource.inferFromDataset(
sparkSession,
csvDataset,
maybeFirstLine,
parsedOptions)
}
verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
}.getOrElse(filteredLines.rdd)
val parsed = linesWithoutHeader.mapPartitions { iter =>
val parser = new UnivocityParser(schema, parsedOptions)
iter.flatMap(line => parser.parse(line))
}
Dataset.ofRows(
sparkSession,
LogicalRDD(schema.toAttributes, parsed)(sparkSession))
}
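// A minimal usage sketch of the method above (illustrative only; the sample rows are
// made up and `spark` / `spark.implicits._` are assumed to be in scope):
//
//   import spark.implicits._
//   val ds = Seq("name,age", "alice,29", "bob,31").toDS()
//   // with `inferSchema`, the input is scanned once more to infer column types
//   val inferred = spark.read.option("header", "true").option("inferSchema", "true").csv(ds)
//   // with neither a user schema nor `inferSchema`, every column is read as a string
//   val allStrings = spark.read.option("header", "true").csv(ds)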
/**
* Loads a CSV file and returns the result as a `DataFrame`.
*
@@ -604,6 +643,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
}
/**
* A convenient function for schema validation in datasources supporting
* `columnNameOfCorruptRecord` as an option.
*/
private def verifyColumnNameOfCorruptRecord(
schema: StructType,
columnNameOfCorruptRecord: String): Unit = {
schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
val f = schema(corruptFieldIndex)
if (f.dataType != StringType || !f.nullable) {
throw new AnalysisException(
"The field for corrupt records must be string type and nullable")
}
}
}
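// A minimal sketch of what this check enforces (illustrative only; the schema and option
// value below are made up): a user-declared corrupt-record column must be a nullable
// string so that malformed input can be stored in it, e.g.
//
//   val schema = new StructType()
//     .add("a", IntegerType)
//     .add("_corrupt_record", StringType)  // nullable StringType: passes the check
//   spark.read
//     .schema(schema)
//     .option("columnNameOfCorruptRecord", "_corrupt_record")
//     .csv(csvDataset)
//
// Declaring that column with any other type, or as non-nullable, makes this method throw
// an AnalysisException on the driver before any data is parsed.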
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
......
@@ -17,12 +17,11 @@
package org.apache.spark.sql.execution.datasources.csv
import java.io.InputStream
import java.nio.charset.{Charset, StandardCharsets}
import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}
import com.univocity.parsers.csv.CsvParser
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.Job
@@ -134,23 +133,33 @@ object TextInputCSVDataSource extends CSVDataSource {
inputPaths: Seq[FileStatus],
parsedOptions: CSVOptions): Option[StructType] = {
val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption match {
case Some(firstLine) =>
val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
val tokenRDD = csv.rdd.mapPartitions { iter =>
val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
val linesWithoutHeader =
CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
val parser = new CsvParser(parsedOptions.asParserSettings)
linesWithoutHeader.map(parser.parseLine)
}
Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
case None =>
// If the first line could not be read, just return the empty schema.
Some(StructType(Nil))
}
val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
Some(inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions))
}
/**
* Infers the schema from `Dataset` that stores CSV string records.
*/
def inferFromDataset(
sparkSession: SparkSession,
csv: Dataset[String],
maybeFirstLine: Option[String],
parsedOptions: CSVOptions): StructType = maybeFirstLine match {
case Some(firstLine) =>
val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
val tokenRDD = csv.rdd.mapPartitions { iter =>
val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
val linesWithoutHeader =
CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
val parser = new CsvParser(parsedOptions.asParserSettings)
linesWithoutHeader.map(parser.parseLine)
}
CSVInferSchema.infer(tokenRDD, header, parsedOptions)
case None =>
// If the first line could not be read, just return the empty schema.
StructType(Nil)
}
private def createBaseDataset(
......
@@ -26,7 +26,7 @@ import org.apache.commons.lang3.time.FastDateFormat
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}
private[csv] class CSVOptions(
class CSVOptions(
@transient private val parameters: CaseInsensitiveMap[String],
defaultTimeZoneId: String,
defaultColumnNameOfCorruptRecord: String)
......
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
private[csv] class UnivocityParser(
class UnivocityParser(
schema: StructType,
requiredSchema: StructType,
private val options: CSVOptions) extends Logging {
......
@@ -129,6 +129,22 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
verifyCars(cars, withHeader = true, checkTypes = true)
}
test("simple csv test with string dataset") {
val csvDataset = spark.read.text(testFile(carsFile)).as[String]
val cars = spark.read
.option("header", "true")
.option("inferSchema", "true")
.csv(csvDataset)
verifyCars(cars, withHeader = true, checkTypes = true)
val carsWithoutHeader = spark.read
.option("header", "false")
.csv(csvDataset)
verifyCars(carsWithoutHeader, withHeader = false, checkTypes = false)
}
test("test inferring booleans") {
val result = spark.read
.format("csv")
@@ -1088,4 +1104,15 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
checkAnswer(df, spark.emptyDataFrame)
}
}
test("Empty string dataset produces empty dataframe and keep user-defined schema") {
val df1 = spark.read.csv(spark.emptyDataset[String])
assert(df1.schema === spark.emptyDataFrame.schema)
checkAnswer(df1, spark.emptyDataFrame)
val schema = StructType(StructField("a", StringType) :: Nil)
val df2 = spark.read.schema(schema).csv(spark.emptyDataset[String])
assert(df2.schema === schema)
}
}