From 455129020ca7f6a162f6f2486a87cc43512cfd2c Mon Sep 17 00:00:00 2001
From: hyukjinkwon <gurwls223@gmail.com>
Date: Wed, 8 Mar 2017 13:43:09 -0800
Subject: [PATCH] [SPARK-15463][SQL] Add an API to load DataFrame from
 Dataset[String] storing CSV

## What changes were proposed in this pull request?

This PR proposes to add an API that loads a `DataFrame` from a `Dataset[String]` storing CSV rows.

It allows pre-processing the data before parsing it as CSV, which enables workarounds for many narrow use cases, for example:

- Case 1 - pre-processing (see the sketch after this list)

  ```scala
  val df = spark.read.text("...")
  // Pre-processing can be done here before parsing the lines as CSV.
  spark.read.csv(df.as[String])
  ```

- Case 2 - use other input formats

  ```scala
  val rdd = spark.sparkContext.newAPIHadoopFile("/file.csv.lzo",
    classOf[com.hadoop.mapreduce.LzoTextInputFormat],
    classOf[org.apache.hadoop.io.LongWritable],
    classOf[org.apache.hadoop.io.Text])
  val stringRdd = rdd.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength))

  spark.read.csv(stringRdd.toDS)
  ```
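
The following is a minimal sketch of the kind of pre-processing Case 1 enables, assuming the new `csv(Dataset[String])` API from this PR; the file path and the filtering predicate are illustrative only:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

// Read the raw lines as text first.
val lines = spark.read.text("/path/to/data.csv").as[String]

// Illustrative pre-processing: drop blank lines and lines starting with '#'.
val cleaned = lines.filter(line => line.trim.nonEmpty && !line.startsWith("#"))

// Parse the cleaned lines as CSV with the API added in this PR.
val df = spark.read.option("header", "true").csv(cleaned)
```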

## How was this patch tested?

Added tests in `CSVSuite` and built with Scala 2.10:

```
./dev/change-scala-version.sh 2.10
./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package
```

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #16854 from HyukjinKwon/SPARK-15463.
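
For reference, here is a minimal usage sketch of the added `csv(csvDataset: Dataset[String])` API, covering both schema inference and a user-specified schema; the file path and column names are illustrative only:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

val csvDataset = spark.read.text("/path/to/cars.csv").as[String]

// Let the reader infer column names and types from the data.
val inferred = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv(csvDataset)

// Or supply a schema up front and skip inference entirely.
val schema = new StructType()
  .add("year", IntegerType)
  .add("make", StringType)
  .add("model", StringType)

val typed = spark.read
  .schema(schema)
  .option("header", "true")
  .csv(csvDataset)
```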
---
 .../apache/spark/sql/DataFrameReader.scala    | 71 ++++++++++++++++---
 .../datasources/csv/CSVDataSource.scala       | 49 +++++++------
 .../datasources/csv/CSVOptions.scala          |  2 +-
 .../datasources/csv/UnivocityParser.scala     |  2 +-
 .../execution/datasources/csv/CSVSuite.scala  | 27 +++++++
 5 files changed, 121 insertions(+), 30 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 41470ae6aa..a5e38e25b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
 import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.command.DDLUtils
+import org.apache.spark.sql.execution.datasources.csv._
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.jdbc._
 import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
@@ -368,14 +369,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         createParser)
     }
 
-    // Check a field requirement for corrupt records here to throw an exception in a driver side
-    schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
-      val f = schema(corruptFieldIndex)
-      if (f.dataType != StringType || !f.nullable) {
-        throw new AnalysisException(
-          "The field for corrupt records must be string type and nullable")
-      }
-    }
+    verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
 
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
       val parser = new JacksonParser(schema, parsedOptions)
@@ -398,6 +392,51 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     csv(Seq(path): _*)
   }
 
+  /**
+   * Loads a `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`.
+   *
+   * If the schema is not specified using the `schema` function and the `inferSchema` option is
+   * enabled, this function goes through the input once to determine the input schema.
+   *
+   * If the schema is not specified using the `schema` function and the `inferSchema` option is
+   * disabled, all columns are treated as string type and only the first line is read to determine
+   * the column names and the number of fields.
+   *
+   * @param csvDataset input Dataset with one CSV row per record
+   * @since 2.2.0
+   */
+  def csv(csvDataset: Dataset[String]): DataFrame = {
+    val parsedOptions: CSVOptions = new CSVOptions(
+      extraOptions.toMap,
+      sparkSession.sessionState.conf.sessionLocalTimeZone)
+    val filteredLines: Dataset[String] =
+      CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions)
+    val maybeFirstLine: Option[String] = filteredLines.take(1).headOption
+
+    val schema = userSpecifiedSchema.getOrElse {
+      TextInputCSVDataSource.inferFromDataset(
+        sparkSession,
+        csvDataset,
+        maybeFirstLine,
+        parsedOptions)
+    }
+
+    verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+
+    val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+      filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
+    }.getOrElse(filteredLines.rdd)
+
+    val parsed = linesWithoutHeader.mapPartitions { iter =>
+      val parser = new UnivocityParser(schema, parsedOptions)
+      iter.flatMap(line => parser.parse(line))
+    }
+
+    Dataset.ofRows(
+      sparkSession,
+      LogicalRDD(schema.toAttributes, parsed)(sparkSession))
+  }
+
   /**
    * Loads a CSV file and returns the result as a `DataFrame`.
    *
@@ -604,6 +643,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
   }
 
+  /**
+   * A convenience function for schema validation in datasources supporting
+   * `columnNameOfCorruptRecord` as an option.
+   */
+  private def verifyColumnNameOfCorruptRecord(
+      schema: StructType,
+      columnNameOfCorruptRecord: String): Unit = {
+    schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+      val f = schema(corruptFieldIndex)
+      if (f.dataType != StringType || !f.nullable) {
+        throw new AnalysisException(
+          "The field for corrupt records must be string type and nullable")
+      }
+    }
+  }
+
   ///////////////////////////////////////////////////////////////////////////////////////
   // Builder pattern config options
   ///////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 47567032b0..35ff924f27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -17,12 +17,11 @@
 
 package org.apache.spark.sql.execution.datasources.csv
 
-import java.io.InputStream
 import java.nio.charset.{Charset, StandardCharsets}
 
-import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}
+import com.univocity.parsers.csv.CsvParser
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.FileStatus
 import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.Job
@@ -134,23 +133,33 @@ object TextInputCSVDataSource extends CSVDataSource {
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): Option[StructType] = {
     val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
-    CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption match {
-      case Some(firstLine) =>
-        val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
-        val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-        val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
-        val tokenRDD = csv.rdd.mapPartitions { iter =>
-          val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
-          val linesWithoutHeader =
-            CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
-          val parser = new CsvParser(parsedOptions.asParserSettings)
-          linesWithoutHeader.map(parser.parseLine)
-        }
-        Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
-      case None =>
-        // If the first line could not be read, just return the empty schema.
-        Some(StructType(Nil))
-    }
+    val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
+    Some(inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions))
+  }
+
+  /**
+   * Infers the schema from a `Dataset` that stores CSV string records.
+   */
+  def inferFromDataset(
+      sparkSession: SparkSession,
+      csv: Dataset[String],
+      maybeFirstLine: Option[String],
+      parsedOptions: CSVOptions): StructType = maybeFirstLine match {
+    case Some(firstLine) =>
+      val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
+      val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+      val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+      val tokenRDD = csv.rdd.mapPartitions { iter =>
+        val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
+        val linesWithoutHeader =
+          CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
+        val parser = new CsvParser(parsedOptions.asParserSettings)
+        linesWithoutHeader.map(parser.parseLine)
+      }
+      CSVInferSchema.infer(tokenRDD, header, parsedOptions)
+    case None =>
+      // If the first line could not be read, just return the empty schema.
+      StructType(Nil)
   }
 
   private def createBaseDataset(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 50503385ad..0b1e5dac2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -26,7 +26,7 @@ import org.apache.commons.lang3.time.FastDateFormat
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}
 
-private[csv] class CSVOptions(
+class CSVOptions(
     @transient private val parameters: CaseInsensitiveMap[String],
     defaultTimeZoneId: String,
     defaultColumnNameOfCorruptRecord: String)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 3b3b87e435..e42ea3fa39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-private[csv] class UnivocityParser(
+class UnivocityParser(
     schema: StructType,
     requiredSchema: StructType,
     private val options: CSVOptions) extends Logging {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index eaedede349..4435e4df38 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -129,6 +129,22 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     verifyCars(cars, withHeader = true, checkTypes = true)
   }
 
+  test("simple csv test with string dataset") {
+    val csvDataset = spark.read.text(testFile(carsFile)).as[String]
+    val cars = spark.read
+      .option("header", "true")
+      .option("inferSchema", "true")
+      .csv(csvDataset)
+
+    verifyCars(cars, withHeader = true, checkTypes = true)
+
+    val carsWithoutHeader = spark.read
+      .option("header", "false")
+      .csv(csvDataset)
+
+    verifyCars(carsWithoutHeader, withHeader = false, checkTypes = false)
+  }
+
   test("test inferring booleans") {
     val result = spark.read
       .format("csv")
@@ -1088,4 +1104,15 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       checkAnswer(df, spark.emptyDataFrame)
     }
   }
+
+  test("Empty string dataset produces empty dataframe and keep user-defined schema") {
+    val df1 = spark.read.csv(spark.emptyDataset[String])
+    assert(df1.schema === spark.emptyDataFrame.schema)
+    checkAnswer(df1, spark.emptyDataFrame)
+
+    val schema = StructType(StructField("a", StringType) :: Nil)
+    val df2 = spark.read.schema(schema).csv(spark.emptyDataset[String])
+    assert(df2.schema === schema)
+  }
+
 }
-- 
GitLab