Commit 5f83c699 authored by Hossein, committed by Reynold Xin

[SPARK-12833][SQL] Initial import of spark-csv

CSV is the most common data format in the "small data" world. It is often the first format people want to try when they see Spark on a single node. Having to rely on a 3rd party component for this leads to poor user experience for new users. This PR merges the popular spark-csv data source package (https://github.com/databricks/spark-csv) with SparkSQL.

This is a first PR to bring the functionality to Spark 2.0 master. We will complete the items outlined in the design document (see JIRA attachment) in follow-up pull requests.
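Once merged, the source is registered under the short name "csv" (see DefaultSource below), so a read looks roughly like the sketch that follows; the path is a placeholder and the option names come from the CSVParameters class added in this patch.

// Sketch only: reading CSV through the built-in data source added by this patch.
val cars = sqlContext.read
  .format("csv")
  .option("header", "true")        // treat the first non-comment line as column names
  .option("inferSchema", "true")   // run CSVInferSchema over the tokenized RDD
  .option("delimiter", ",")
  .load("/path/to/cars.csv")       // placeholder path

cars.printSchema()
cars.show()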

Author: Hossein <hossein@databricks.com>
Author: Reynold Xin <rxin@databricks.com>

Closes #10766 from rxin/csv.
parent c5e7076d
Showing changed files with 1010 additions and 8 deletions
......@@ -86,3 +86,5 @@ org.apache.spark.scheduler.SparkHistoryListenerFactory
.*parquet
LZ4BlockInputStream.java
spark-deps-.*
.*csv
.*tsv
......@@ -610,7 +610,43 @@ Vis.js uses and redistributes the following third-party libraries:
===============================================================================
The CSS style for the navigation sidebar of the documentation was originally
submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project
is distributed under the 3-Clause BSD license.
===============================================================================
For CSV functionality:
/*
* Copyright 2014 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Copyright 2015 Ayasdi Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
......@@ -184,6 +184,7 @@ tachyon-underfs-hdfs-0.8.2.jar
tachyon-underfs-local-0.8.2.jar
tachyon-underfs-s3-0.8.2.jar
uncommons-maths-1.2.2a.jar
univocity-parsers-1.5.6.jar
unused-1.0.0.jar
xbean-asm5-shaded-4.4.jar
xmlenc-0.52.jar
......
......@@ -175,6 +175,7 @@ tachyon-underfs-hdfs-0.8.2.jar
tachyon-underfs-local-0.8.2.jar
tachyon-underfs-s3-0.8.2.jar
uncommons-maths-1.2.2a.jar
univocity-parsers-1.5.6.jar
unused-1.0.0.jar
xbean-asm5-shaded-4.4.jar
xmlenc-0.52.jar
......
......@@ -176,6 +176,7 @@ tachyon-underfs-hdfs-0.8.2.jar
tachyon-underfs-local-0.8.2.jar
tachyon-underfs-s3-0.8.2.jar
uncommons-maths-1.2.2a.jar
univocity-parsers-1.5.6.jar
unused-1.0.0.jar
xbean-asm5-shaded-4.4.jar
xmlenc-0.52.jar
......
......@@ -182,6 +182,7 @@ tachyon-underfs-hdfs-0.8.2.jar
tachyon-underfs-local-0.8.2.jar
tachyon-underfs-s3-0.8.2.jar
uncommons-maths-1.2.2a.jar
univocity-parsers-1.5.6.jar
unused-1.0.0.jar
xbean-asm5-shaded-4.4.jar
xercesImpl-2.9.1.jar
......
......@@ -36,6 +36,12 @@
</properties>
<dependencies>
<dependency>
<groupId>com.univocity</groupId>
<artifactId>univocity-parsers</artifactId>
<version>1.5.6</version>
<type>jar</type>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
......
org.apache.spark.sql.execution.datasources.csv.DefaultSource
org.apache.spark.sql.execution.datasources.jdbc.DefaultSource
org.apache.spark.sql.execution.datasources.json.DefaultSource
org.apache.spark.sql.execution.datasources.parquet.DefaultSource
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.csv
import java.math.BigDecimal
import java.sql.{Date, Timestamp}
import java.text.NumberFormat
import java.util.Locale
import scala.util.control.Exception._
import scala.util.Try
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.types._
private[sql] object CSVInferSchema {
/**
* Similar to the JSON schema inference
* 1. Infer type of each row
* 2. Merge row types to find common type
* 3. Replace any null types with string type
* TODO(hossein): Can we reuse JSON schema inference? [SPARK-12670]
*/
def apply(
tokenRdd: RDD[Array[String]],
header: Array[String],
nullValue: String = ""): StructType = {
val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val rootTypes: Array[DataType] =
tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes)
val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
StructField(thisHeader, rootType, nullable = true)
}
StructType(structFields)
}
private def inferRowType(nullValue: String)
(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
var i = 0
while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing.
rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue)
i += 1
}
rowSoFar
}
private[csv] def mergeRowTypes(
first: Array[DataType],
second: Array[DataType]): Array[DataType] = {
first.zipAll(second, NullType, NullType).map { case ((a, b)) =>
val tpe = findTightestCommonType(a, b).getOrElse(StringType)
tpe match {
case _: NullType => StringType
case other => other
}
}
}
/**
* Infer type of string field. Given known type Double, and a string "1", there is no
* point checking if it is an Int, as the final type must be Double or higher.
*/
private[csv] def inferField(
typeSoFar: DataType, field: String, nullValue: String = ""): DataType = {
if (field == null || field.isEmpty || field == nullValue) {
typeSoFar
} else {
typeSoFar match {
case NullType => tryParseInteger(field)
case IntegerType => tryParseInteger(field)
case LongType => tryParseLong(field)
case DoubleType => tryParseDouble(field)
case TimestampType => tryParseTimestamp(field)
case StringType => StringType
case other: DataType =>
throw new UnsupportedOperationException(s"Unexpected data type $other")
}
}
}
private def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) {
IntegerType
} else {
tryParseLong(field)
}
private def tryParseLong(field: String): DataType = if ((allCatch opt field.toLong).isDefined) {
LongType
} else {
tryParseDouble(field)
}
private def tryParseDouble(field: String): DataType = {
if ((allCatch opt field.toDouble).isDefined) {
DoubleType
} else {
tryParseTimestamp(field)
}
}
def tryParseTimestamp(field: String): DataType = {
if ((allCatch opt Timestamp.valueOf(field)).isDefined) {
TimestampType
} else {
stringType()
}
}
// Defining a function to return the StringType constant is necessary in order to work around
// a Scala compiler issue which leads to runtime incompatibilities with certain Spark versions;
// see issue #128 for more details.
private def stringType(): DataType = {
StringType
}
private val numericPrecedence: IndexedSeq[DataType] = HiveTypeCoercion.numericPrecedence
/**
* Copied from internal Spark api
* [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]
*/
val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
// Promote numeric types to the highest of the two and all numeric types to unlimited decimal
case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
Some(numericPrecedence(index))
case _ => None
}
}
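As a hypothetical walk-through of the inference steps documented above (inferField is private[csv], so this would have to live inside the csv package): a column containing "1", "2.5" and an empty field starts at NullType, is promoted to IntegerType, widened to DoubleType, and keeps DoubleType for the empty value.

// Sketch, assuming it compiles inside org.apache.spark.sql.execution.datasources.csv.
import org.apache.spark.sql.types._

val column = Seq("1", "2.5", "")   // one CSV column across three rows
val inferred = column.foldLeft(NullType: DataType) { (soFar, field) =>
  CSVInferSchema.inferField(soFar, field)   // empty fields keep the type inferred so far
}
// inferred == DoubleType; mergeRowTypes later replaces any remaining NullType with StringType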
object CSVTypeCast {
/**
* Casts given string datum to specified type.
* Currently we do not support complex types (ArrayType, MapType, StructType).
*
* For string types, this is simply the datum.
* For other nullable types, this is null if the string datum is empty.
*
* @param datum string value
* @param castType SparkSQL type
*/
private[csv] def castTo(
datum: String,
castType: DataType,
nullable: Boolean = true,
nullValue: String = ""): Any = {
if (datum == nullValue && nullable && (!castType.isInstanceOf[StringType])) {
null
} else {
castType match {
case _: ByteType => datum.toByte
case _: ShortType => datum.toShort
case _: IntegerType => datum.toInt
case _: LongType => datum.toLong
case _: FloatType => Try(datum.toFloat)
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
case _: DoubleType => Try(datum.toDouble)
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
case _: BooleanType => datum.toBoolean
case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
// TODO(hossein): would be good to support other common timestamp formats
case _: TimestampType => Timestamp.valueOf(datum)
// TODO(hossein): would be good to support other common date formats
case _: DateType => Date.valueOf(datum)
case _: StringType => datum
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
}
}
}
/**
* Helper method that converts the string representation of a character to the actual character.
* It handles some Java escape sequences and throws an exception if the given string is longer
* than one character.
*
*/
@throws[IllegalArgumentException]
private[csv] def toChar(str: String): Char = {
if (str.charAt(0) == '\\') {
str.charAt(1)
match {
case 't' => '\t'
case 'r' => '\r'
case 'b' => '\b'
case 'f' => '\f'
case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options
case '\'' => '\''
case 'u' if str == """\u0000""" => '\u0000'
case _ =>
throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str")
}
} else if (str.length == 1) {
str.charAt(0)
} else {
throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str")
}
}
}
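And a few illustrative casts and delimiter conversions, again as a sketch (castTo and toChar are private[csv]; the NumberFormat fallback result assumes an en_US default locale):

// Sketch of CSVTypeCast behavior described above.
import org.apache.spark.sql.types._

CSVTypeCast.castTo("10", IntegerType)        // 10
CSVTypeCast.castTo("", IntegerType)          // null: empty string matches the default nullValue
CSVTypeCast.castTo("", StringType)           // "": string columns are not nulled here
CSVTypeCast.castTo("1,000.5", DoubleType)    // 1000.5 via the NumberFormat fallback (en_US assumed)
CSVTypeCast.toChar("\\t")                    // '\t' from the two-character escape sequence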
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.csv
import java.nio.charset.Charset
import org.apache.spark.Logging
private[sql] case class CSVParameters(parameters: Map[String, String]) extends Logging {
private def getChar(paramName: String, default: Char): Char = {
val paramValue = parameters.get(paramName)
paramValue match {
case None => default
case Some(value) if value.length == 0 => '\0'
case Some(value) if value.length == 1 => value.charAt(0)
case _ => throw new RuntimeException(s"$paramName cannot be more than one character")
}
}
private def getBool(paramName: String, default: Boolean = false): Boolean = {
val param = parameters.getOrElse(paramName, default.toString)
if (param.toLowerCase() == "true") {
true
} else if (param.toLowerCase == "false") {
false
} else {
throw new Exception(s"$paramName flag can be true or false")
}
}
val delimiter = CSVTypeCast.toChar(parameters.getOrElse("delimiter", ","))
val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
val charset = parameters.getOrElse("charset", Charset.forName("UTF-8").name())
val quote = getChar("quote", '\"')
val escape = getChar("escape", '\\')
val comment = getChar("comment", '\0')
val headerFlag = getBool("header")
val inferSchemaFlag = getBool("inferSchema")
val ignoreLeadingWhiteSpaceFlag = getBool("ignoreLeadingWhiteSpace")
val ignoreTrailingWhiteSpaceFlag = getBool("ignoreTrailingWhiteSpace")
// Limit the number of lines we'll search for a header row that isn't comment-prefixed
val MAX_COMMENT_LINES_IN_HEADER = 10
// Parse mode flags
if (!ParseModes.isValidMode(parseMode)) {
logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.")
}
val failFast = ParseModes.isFailFastMode(parseMode)
val dropMalformed = ParseModes.isDropMalformedMode(parseMode)
val permissive = ParseModes.isPermissiveMode(parseMode)
val nullValue = parameters.getOrElse("nullValue", "")
val maxColumns = 20480
val maxCharsPerColumn = 100000
val inputBufferSize = 128
val isCommentSet = this.comment != '\0'
val rowSeparator = "\n"
}
private[csv] object ParseModes {
val PERMISSIVE_MODE = "PERMISSIVE"
val DROP_MALFORMED_MODE = "DROPMALFORMED"
val FAIL_FAST_MODE = "FAILFAST"
val DEFAULT = PERMISSIVE_MODE
def isValidMode(mode: String): Boolean = {
mode.toUpperCase match {
case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true
case _ => false
}
}
def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE
def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE
def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) {
mode.toUpperCase == PERMISSIVE_MODE
} else {
true // We default to permissive if the mode string is not valid
}
}
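A brief sketch of how user options map onto the fields above (CSVParameters is private[sql], so this is illustrative only):

// Sketch: constructing CSVParameters from reader options.
val params = CSVParameters(Map(
  "delimiter" -> "\t",          // converted through CSVTypeCast.toChar
  "header"    -> "true",        // parsed by getBool
  "mode"      -> "DROPMALFORMED"
))
// params.delimiter == '\t', params.headerFlag and params.dropMalformed are set;
// unspecified options fall back to the defaults above (quote '"', escape '\\', UTF-8, ...).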
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.csv
import java.io.{ByteArrayOutputStream, OutputStreamWriter, StringReader}
import com.univocity.parsers.csv.{CsvParser, CsvParserSettings, CsvWriter, CsvWriterSettings}
import org.apache.spark.Logging
/**
* Read and parse CSV-like input
*
* @param params Parameters object
* @param headers headers for the columns
*/
private[sql] abstract class CsvReader(params: CSVParameters, headers: Seq[String]) {
protected lazy val parser: CsvParser = {
val settings = new CsvParserSettings()
val format = settings.getFormat
format.setDelimiter(params.delimiter)
format.setLineSeparator(params.rowSeparator)
format.setQuote(params.quote)
format.setQuoteEscape(params.escape)
format.setComment(params.comment)
settings.setIgnoreLeadingWhitespaces(params.ignoreLeadingWhiteSpaceFlag)
settings.setIgnoreTrailingWhitespaces(params.ignoreTrailingWhiteSpaceFlag)
settings.setReadInputOnSeparateThread(false)
settings.setInputBufferSize(params.inputBufferSize)
settings.setMaxColumns(params.maxColumns)
settings.setNullValue(params.nullValue)
settings.setMaxCharsPerColumn(params.maxCharsPerColumn)
if (headers != null) settings.setHeaders(headers: _*)
new CsvParser(settings)
}
}
/**
* Converts a sequence of strings to a CSV string
*
* @param params Parameters object for configuration
* @param headers headers for columns
*/
private[sql] class LineCsvWriter(params: CSVParameters, headers: Seq[String]) extends Logging {
private val writerSettings = new CsvWriterSettings
private val format = writerSettings.getFormat
format.setDelimiter(params.delimiter)
format.setLineSeparator(params.rowSeparator)
format.setQuote(params.quote)
format.setQuoteEscape(params.escape)
format.setComment(params.comment)
writerSettings.setNullValue(params.nullValue)
writerSettings.setEmptyValue(params.nullValue)
writerSettings.setSkipEmptyLines(true)
writerSettings.setQuoteAllFields(false)
writerSettings.setHeaders(headers: _*)
def writeRow(row: Seq[String], includeHeader: Boolean): String = {
val buffer = new ByteArrayOutputStream()
val outputWriter = new OutputStreamWriter(buffer)
val writer = new CsvWriter(outputWriter, writerSettings)
if (includeHeader) {
writer.writeHeaders()
}
writer.writeRow(row.toArray: _*)
writer.close()
buffer.toString.stripLineEnd
}
}
/**
* Parser for parsing a line at a time. Not efficient for bulk data.
*
* @param params Parameters object
*/
private[sql] class LineCsvReader(params: CSVParameters)
extends CsvReader(params, null) {
/**
* parse a line
*
* @param line a String with no newline at the end
* @return array of strings where each string is a field in the CSV record
*/
def parseLine(line: String): Array[String] = {
parser.beginParsing(new StringReader(line))
val parsed = parser.parseNext()
parser.stopParsing()
parsed
}
}
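For example, the single-line reader could be exercised like this (a sketch using the default CSVParameters):

// Sketch: tokenizing one line with LineCsvReader and default parameters.
val lineReader = new LineCsvReader(CSVParameters(Map.empty[String, String]))
val fields = lineReader.parseLine("""2012,"Tesla",S,"No comment",5000.10""")
// fields: Array(2012, Tesla, S, No comment, 5000.10) -- quote characters are stripped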
/**
* Parser for parsing lines in bulk. Use this when efficiency is desired.
*
* @param iter iterator over lines in the file
* @param params Parameters object
* @param headers headers for the columns
*/
private[sql] class BulkCsvReader(
iter: Iterator[String],
params: CSVParameters,
headers: Seq[String])
extends CsvReader(params, headers) with Iterator[Array[String]] {
private val reader = new StringIteratorReader(iter)
parser.beginParsing(reader)
private var nextRecord = parser.parseNext()
/**
* get the next parsed line.
* @return array of strings where each string is a field in the CSV record
*/
override def next(): Array[String] = {
val curRecord = nextRecord
if (curRecord != null) {
nextRecord = parser.parseNext()
} else {
throw new NoSuchElementException("next record is null")
}
curRecord
}
override def hasNext: Boolean = nextRecord != null
}
/**
* A Reader that "reads" from a sequence of lines. Spark's textFile method removes newlines at
* the end of each line; the univocity parser requires a Reader that provides access to the data
* to be parsed and needs the newlines to be present.
* @param iter iterator over RDD[String]
*/
private class StringIteratorReader(val iter: Iterator[String]) extends java.io.Reader {
private var next: Long = 0
private var length: Long = 0 // length of input so far
private var start: Long = 0
private var str: String = null // current string from iter
/**
* fetch next string from iter, if done with current one
* pretend there is a newline at the end of every string we get from iter
*/
private def refill(): Unit = {
if (length == next) {
if (iter.hasNext) {
str = iter.next()
start = length
length += (str.length + 1) // allowance for newline removed by SparkContext.textFile()
} else {
str = null
}
}
}
/**
* read the next character; if at the end of the string, pretend there is a newline
*/
override def read(): Int = {
refill()
if (next >= length) {
-1
} else {
val cur = next - start
next += 1
if (cur == str.length) '\n' else str.charAt(cur.toInt)
}
}
/**
* read from str into cbuf
*/
override def read(cbuf: Array[Char], off: Int, len: Int): Int = {
refill()
var n = 0
if ((off < 0) || (off > cbuf.length) || (len < 0) ||
((off + len) > cbuf.length) || ((off + len) < 0)) {
throw new IndexOutOfBoundsException()
} else if (len == 0) {
n = 0
} else {
if (next >= length) { // end of input
n = -1
} else {
n = Math.min(length - next, len).toInt // lesser of amount of input available or buf size
if (n == length - next) {
str.getChars((next - start).toInt, (next - start + n - 1).toInt, cbuf, off)
cbuf(off + n - 1) = '\n'
} else {
str.getChars((next - start).toInt, (next - start + n).toInt, cbuf, off)
}
next += n
if (n < len) {
val m = read(cbuf, off + n, len - n) // have more space, fetch more input from iter
if (m != -1) n += m
}
}
}
n
}
override def skip(ns: Long): Long = {
throw new IllegalArgumentException("Skip not implemented")
}
override def ready: Boolean = {
refill()
true
}
override def markSupported: Boolean = false
override def mark(readAheadLimit: Int): Unit = {
throw new IllegalArgumentException("Mark not implemented")
}
override def reset(): Unit = {
throw new IllegalArgumentException("Mark and hence reset not implemented")
}
override def close(): Unit = { }
}
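The bulk path wraps an iterator of lines in the StringIteratorReader above, so univocity sees one continuous stream with newlines restored. A sketch (in practice the iterator comes from an RDD partition, as in CSVRelation.univocityTokenizer below):

// Sketch: bulk-parsing an in-memory iterator the way univocityTokenizer does per partition.
val lines = Iterator("2012,Tesla,S", "1997,Ford,E350")
val records = new BulkCsvReader(lines, CSVParameters(Map.empty[String, String]),
  headers = Seq("year", "make", "model"))
records.foreach(tokens => println(tokens.mkString("|")))   // 2012|Tesla|S then 1997|Ford|E350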
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.csv
import java.nio.charset.Charset
import scala.util.control.NonFatal
import com.google.common.base.Objects
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.hadoop.mapreduce.RecordWriter
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
private[csv] class CSVRelation(
private val inputRDD: Option[RDD[String]],
override val paths: Array[String],
private val maybeDataSchema: Option[StructType],
override val userDefinedPartitionColumns: Option[StructType],
private val parameters: Map[String, String])
(@transient val sqlContext: SQLContext) extends HadoopFsRelation with Serializable {
override lazy val dataSchema: StructType = maybeDataSchema match {
case Some(structType) => structType
case None => inferSchema(paths)
}
private val params = new CSVParameters(parameters)
@transient
private var cachedRDD: Option[RDD[String]] = None
private def readText(location: String): RDD[String] = {
if (Charset.forName(params.charset) == Charset.forName("UTF-8")) {
sqlContext.sparkContext.textFile(location)
} else {
sqlContext.sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](location)
.mapPartitions { _.map { pair =>
new String(pair._2.getBytes, 0, pair._2.getLength, params.charset)
}
}
}
}
private def baseRdd(inputPaths: Array[String]): RDD[String] = {
inputRDD.getOrElse {
cachedRDD.getOrElse {
val rdd = readText(inputPaths.mkString(","))
cachedRDD = Some(rdd)
rdd
}
}
}
private def tokenRdd(header: Array[String], inputPaths: Array[String]): RDD[Array[String]] = {
val rdd = baseRdd(inputPaths)
// Make sure firstLine is materialized before sending to executors
val firstLine = if (params.headerFlag) findFirstLine(rdd) else null
CSVRelation.univocityTokenizer(rdd, header, firstLine, params)
}
/**
* This supports eliminating unneeded columns before producing an RDD
* containing all of its tuples as Row objects. It reads all the tokens of each line
* and then drops unneeded tokens without casting and type-checking, by mapping
* the indices produced by `requiredColumns` onto the indices of the tokens.
* TODO: Switch to using buildInternalScan
*/
override def buildScan(requiredColumns: Array[String], inputs: Array[FileStatus]): RDD[Row] = {
val pathsString = inputs.map(_.getPath.toUri.toString)
val header = schema.fields.map(_.name)
val tokenizedRdd = tokenRdd(header, pathsString)
CSVRelation.parseCsv(tokenizedRdd, schema, requiredColumns, inputs, sqlContext, params)
}
override def prepareJobForWrite(job: Job): OutputWriterFactory = {
new CSVOutputWriterFactory(params)
}
override def hashCode(): Int = Objects.hashCode(paths.toSet, dataSchema, schema, partitionColumns)
override def equals(other: Any): Boolean = other match {
case that: CSVRelation => {
val equalPath = paths.toSet == that.paths.toSet
val equalDataSchema = dataSchema == that.dataSchema
val equalSchema = schema == that.schema
val equalPartitionColums = partitionColumns == that.partitionColumns
equalPath && equalDataSchema && equalSchema && equalPartitionColums
}
case _ => false
}
private def inferSchema(paths: Array[String]): StructType = {
val rdd = baseRdd(Array(paths.head))
val firstLine = findFirstLine(rdd)
val firstRow = new LineCsvReader(params).parseLine(firstLine)
val header = if (params.headerFlag) {
firstRow
} else {
firstRow.zipWithIndex.map { case (value, index) => s"C$index" }
}
val parsedRdd = tokenRdd(header, paths)
if (params.inferSchemaFlag) {
CSVInferSchema(parsedRdd, header, params.nullValue)
} else {
// By default fields are assumed to be StringType
val schemaFields = header.map { fieldName =>
StructField(fieldName.toString, StringType, nullable = true)
}
StructType(schemaFields)
}
}
/**
* Returns the first line of the first non-empty file in path
*/
private def findFirstLine(rdd: RDD[String]): String = {
if (params.isCommentSet) {
rdd.take(params.MAX_COMMENT_LINES_IN_HEADER)
.find(!_.startsWith(params.comment.toString))
.getOrElse(sys.error(s"No uncommented header line in " +
s"first ${params.MAX_COMMENT_LINES_IN_HEADER} lines"))
} else {
rdd.first()
}
}
}
object CSVRelation extends Logging {
def univocityTokenizer(
file: RDD[String],
header: Seq[String],
firstLine: String,
params: CSVParameters): RDD[Array[String]] = {
// If header is set, make sure firstLine is materialized before sending to executors.
file.mapPartitionsWithIndex({
case (split, iter) => new BulkCsvReader(
if (params.headerFlag) iter.filterNot(_ == firstLine) else iter,
params,
headers = header)
}, true)
}
def parseCsv(
tokenizedRDD: RDD[Array[String]],
schema: StructType,
requiredColumns: Array[String],
inputs: Array[FileStatus],
sqlContext: SQLContext,
params: CSVParameters): RDD[Row] = {
val schemaFields = schema.fields
val requiredFields = StructType(requiredColumns.map(schema(_))).fields
val safeRequiredFields = if (params.dropMalformed) {
// If `dropMalformed` is enabled, then it needs to parse all the values
// so that we can decide which row is malformed.
requiredFields ++ schemaFields.filterNot(requiredFields.contains(_))
} else {
requiredFields
}
if (requiredColumns.isEmpty) {
sqlContext.sparkContext.emptyRDD[Row]
} else {
val safeRequiredIndices = new Array[Int](safeRequiredFields.length)
schemaFields.zipWithIndex.filter {
case (field, _) => safeRequiredFields.contains(field)
}.foreach {
case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
}
val rowArray = new Array[Any](safeRequiredIndices.length)
val requiredSize = requiredFields.length
tokenizedRDD.flatMap { tokens =>
if (params.dropMalformed && schemaFields.length != tokens.size) {
logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
None
} else if (params.failFast && schemaFields.length != tokens.size) {
throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
s"${tokens.mkString(params.delimiter.toString)}")
} else {
val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.size) {
tokens ++ new Array[String](schemaFields.length - tokens.size)
} else if (params.permissive && schemaFields.length < tokens.size) {
tokens.take(schemaFields.length)
} else {
tokens
}
try {
var index: Int = 0
var subIndex: Int = 0
while (subIndex < safeRequiredIndices.length) {
index = safeRequiredIndices(subIndex)
val field = schemaFields(index)
rowArray(subIndex) = CSVTypeCast.castTo(
indexSafeTokens(index),
field.dataType,
field.nullable,
params.nullValue)
subIndex = subIndex + 1
}
Some(Row.fromSeq(rowArray.take(requiredSize)))
} catch {
case NonFatal(e) if params.dropMalformed =>
logWarning("Parse exception. " +
s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
None
}
}
}
}
}
}
private[sql] class CSVOutputWriterFactory(params: CSVParameters) extends OutputWriterFactory {
override def newInstance(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new CsvOutputWriter(path, dataSchema, context, params)
}
}
private[sql] class CsvOutputWriter(
path: String,
dataSchema: StructType,
context: TaskAttemptContext,
params: CSVParameters) extends OutputWriter with Logging {
// create the writer without a separator inserted between two records
private[this] val text = new Text()
private val recordWriter: RecordWriter[NullWritable, Text] = {
new TextOutputFormat[NullWritable, Text]() {
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
val configuration = context.getConfiguration
val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID")
val taskAttemptId = context.getTaskAttemptID
val split = taskAttemptId.getTaskID.getId
new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
}
}.getRecordWriter(context)
}
private var firstRow: Boolean = params.headerFlag
private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)
private def rowToString(row: Seq[Any]): Seq[String] = row.map { field =>
if (field != null) {
field.toString
} else {
params.nullValue
}
}
override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
override protected[sql] def writeInternal(row: InternalRow): Unit = {
// TODO: Instead of converting and writing every row, we should use the univocity buffer
val resultString = csvWriter.writeRow(rowToString(row.toSeq(dataSchema)), firstRow)
if (firstRow) {
firstRow = false
}
text.set(resultString)
recordWriter.write(NullWritable.get(), text)
}
override def close(): Unit = {
recordWriter.close(context)
}
}
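The three parse modes handled in parseCsv above map onto reader options roughly as follows (a sketch with a placeholder path): PERMISSIVE (the default) pads short rows with nulls and truncates long ones, DROPMALFORMED logs and skips rows with the wrong column count, and FAILFAST throws on the first bad row.

// Sketch: selecting a parse mode for a file with inconsistent column counts.
val df = sqlContext.read
  .format("csv")
  .option("header", "true")
  .option("mode", "DROPMALFORMED")        // or "PERMISSIVE" / "FAILFAST"
  .load("/path/to/cars-malformed.csv")    // placeholder path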
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.csv
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
/**
* Provides access to CSV data from pure SQL statements.
*/
class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
override def shortName(): String = "csv"
/**
* Creates a new relation for data stored in CSV format, given parameters and a user-supplied schema.
*/
override def createRelation(
sqlContext: SQLContext,
paths: Array[String],
dataSchema: Option[StructType],
partitionColumns: Option[StructType],
parameters: Map[String, String]): HadoopFsRelation = {
new CSVRelation(
None,
paths,
dataSchema,
partitionColumns,
parameters)(sqlContext)
}
}
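The "pure SQL statements" access path mentioned in the class comment would look roughly like this sketch (table name and path are hypothetical):

// Sketch: registering the source through SQL using the "csv" short name above.
sqlContext.sql(
  """CREATE TEMPORARY TABLE cars
    |USING csv
    |OPTIONS (path "/path/to/cars.csv", header "true", inferSchema "true")
  """.stripMargin)
sqlContext.sql("SELECT year, make, model FROM cars").show()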
......@@ -145,7 +145,7 @@ private[json] object InferSchema {
/**
* Convert NullType to StringType and remove StructTypes with no fields
*/
-  private def canonicalizeType: DataType => Option[DataType] = {
+  private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match {
case at @ ArrayType(elementType, _) =>
for {
canonicalType <- canonicalizeType(elementType)
......@@ -154,15 +154,15 @@ private[json] object InferSchema {
}
case StructType(fields) =>
-      val canonicalFields = for {
+      val canonicalFields: Array[StructField] = for {
        field <- fields
-        if field.name.nonEmpty
+        if field.name.length > 0
canonicalType <- canonicalizeType(field.dataType)
} yield {
field.copy(dataType = canonicalType)
}
-      if (canonicalFields.nonEmpty) {
+      if (canonicalFields.length > 0) {
Some(StructType(canonicalFields))
} else {
// per SPARK-8093: empty structs should be deleted
......@@ -217,10 +217,9 @@ private[json] object InferSchema {
(t1, t2) match {
// Double support larger range than fixed decimal, DecimalType.Maximum should be enough
// in most case, also have better precision.
-      case (DoubleType, t: DecimalType) =>
-        DoubleType
-      case (t: DecimalType, DoubleType) =>
+      case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
        DoubleType
case (t1: DecimalType, t2: DecimalType) =>
val scale = math.max(t1.scale, t2.scale)
val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
......
year|make|model|comment|blank
'2012'|'Tesla'|'S'| 'No comment'|
1997|Ford|E350|'Go get one now they are going fast'|
2015|Chevy|Volt
year,make,model,comment,blank
"2012","Tesla","S",null,
1997,Ford,E350,"Go get one now they are going fast",
null,Chevy,Volt
year,make,model,comment,blank
"2012,Tesla,S,No comment
1997,Ford,E350,Go get one now they are going fast"
"2015,"Chevy",Volt,
year,make,model,comment,blank
"2012","Tesla","S","No comment",
1997,Ford,E350,"Go get one now they are going fast",
2015,Chevy,Volt
year make model price comment blank
2012 Tesla S "80,000.65"
1997 Ford E350 35,000 "Go get one now they are going fast"
2015 Chevy Volt 5,000.10
yearmakemodelcommentblank
"2012""Tesla""S""No comment"
1997FordE350"Go get one now they are oing fast"
2015ChevyVolt