Commit d6cbec75 authored by hyukjinkwon, committed by Wenchen Fan

[SPARK-18943][SQL] Avoid per-record type dispatch in CSV when reading

## What changes were proposed in this pull request?

`CSVRelation.csvParser` does type dispatch for each value in each row. We can avoid this because the schema is already available in `CSVRelation`.

So, this PR proposes that the converters are created first according to the schema, and then applied to each value.
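
To illustrate the idea, here is a minimal, simplified sketch (the field names and reduced type handling are hypothetical, not the actual `CSVTypeCast` code): the match on `DataType` runs once per field when the converters are built, and the per-record path only applies the pre-bound functions.

```scala
import org.apache.spark.sql.types._

// A converter turns one string token into a typed value.
type ValueConverter = String => Any

// Type dispatch happens here, once per field, instead of once per value.
def makeConverter(dataType: DataType): ValueConverter = dataType match {
  case IntegerType => (s: String) => s.toInt
  case DoubleType  => (s: String) => s.toDouble
  case BooleanType => (s: String) => s.toBoolean
  case _           => (s: String) => s // fall back to the raw string
}

val schema = StructType(
  StructField("id", IntegerType) ::
  StructField("price", DoubleType) ::
  StructField("inStock", BooleanType) :: Nil)

// Built once, reused for every record.
val converters: Array[ValueConverter] = schema.fields.map(f => makeConverter(f.dataType))

// Per-record work is just index-aligned function application.
val tokens = Array("1", "9.99", "true")
val values = tokens.zipWithIndex.map { case (token, i) => converters(i)(token) }
```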

I ran some small benchmarks, shown below, after reproducing the logic in https://github.com/apache/spark/blob/7c33b0fd050f3d2b08c1cfd7efbff8166832c1af/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala#L170-L178 to test the updated logic.

```scala
test("Benchmark for CSV converter") {
  var numMalformedRecords = 0
  val N = 500 << 12
  val schema = StructType(
    StructField("a", StringType) ::
    StructField("b", StringType) ::
    StructField("c", StringType) ::
    StructField("d", StringType) :: Nil)

  val row = Array("1.0", "test", "2015-08-20 14:57:00", "FALSE")
  val data = spark.sparkContext.parallelize(List.fill(N)(row))
  val parser = CSVRelation.csvParser(schema, schema.fieldNames, CSVOptions())

  val benchmark = new Benchmark("CSV converter", N)
  benchmark.addCase("cast CSV string tokens", 10) { _ =>
    data.flatMap { recordTokens =>
      parser(recordTokens, numMalformedRecords)
    }.collect()
  }
  benchmark.run()
}
```

**Before**

```
CSV converter:                           Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
cast CSV string tokens                        1061 / 1130          1.9         517.9       1.0X
```

**After**

```
CSV converter:                           Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
cast CSV string tokens                         940 / 1011          2.2         459.2       1.0X
```

## How was this patch tested?

Tests in `CSVTypeCastSuite` and `CSVRelation`

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #16351 from HyukjinKwon/type-dispatch.
parent f2ceb2ab
@@ -215,84 +215,121 @@ private[csv] object CSVInferSchema {
}
private[csv] object CSVTypeCast {
// A `ValueConverter` is responsible for converting the given value to a desired type.
private type ValueConverter = String => Any
/**
* Casts given string datum to specified type.
* Currently we do not support complex types (ArrayType, MapType, StructType).
* Create converters which cast each given string datum to each specified type in given schema.
* Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`).
*
* For string types, this is simply the datum. For other types.
* For string types, this is simply the datum.
* For other types, this is converted into the value according to the type.
* For other nullable types, returns null if it is null or equals to the value specified
* in `nullValue` option.
*
* @param datum string value
* @param name field name in schema.
* @param castType data type to cast `datum` into.
* @param nullable nullability for the field.
* @param schema schema that contains data types to cast the given value into.
* @param options CSV options.
*/
def castTo(
def makeConverters(
schema: StructType,
options: CSVOptions = CSVOptions()): Array[ValueConverter] = {
schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
}
/**
* Create a converter which converts the string value to a value according to a desired type.
*/
def makeConverter(
name: String,
dataType: DataType,
nullable: Boolean = true,
options: CSVOptions = CSVOptions()): ValueConverter = dataType match {
case _: ByteType => (d: String) =>
nullSafeDatum(d, name, nullable, options)(_.toByte)
case _: ShortType => (d: String) =>
nullSafeDatum(d, name, nullable, options)(_.toShort)
case _: IntegerType => (d: String) =>
nullSafeDatum(d, name, nullable, options)(_.toInt)
case _: LongType => (d: String) =>
nullSafeDatum(d, name, nullable, options)(_.toLong)
case _: FloatType => (d: String) =>
nullSafeDatum(d, name, nullable, options) {
case options.nanValue => Float.NaN
case options.negativeInf => Float.NegativeInfinity
case options.positiveInf => Float.PositiveInfinity
case datum =>
Try(datum.toFloat)
.getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue())
}
case _: DoubleType => (d: String) =>
nullSafeDatum(d, name, nullable, options) {
case options.nanValue => Double.NaN
case options.negativeInf => Double.NegativeInfinity
case options.positiveInf => Double.PositiveInfinity
case datum =>
Try(datum.toDouble)
.getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue())
}
case _: BooleanType => (d: String) =>
nullSafeDatum(d, name, nullable, options)(_.toBoolean)
case dt: DecimalType => (d: String) =>
nullSafeDatum(d, name, nullable, options) { datum =>
val value = new BigDecimal(datum.replaceAll(",", ""))
Decimal(value, dt.precision, dt.scale)
}
case _: TimestampType => (d: String) =>
nullSafeDatum(d, name, nullable, options) { datum =>
// This one will lose microseconds parts.
// See https://issues.apache.org/jira/browse/SPARK-10681.
Try(options.timestampFormat.parse(datum).getTime * 1000L)
.getOrElse {
// If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
// compatibility.
DateTimeUtils.stringToTime(datum).getTime * 1000L
}
}
case _: DateType => (d: String) =>
nullSafeDatum(d, name, nullable, options) { datum =>
// This one will lose microseconds parts.
// See https://issues.apache.org/jira/browse/SPARK-10681.x
Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime))
.getOrElse {
// If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
// compatibility.
DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
}
}
case _: StringType => (d: String) =>
nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_))
case udt: UserDefinedType[_] => (datum: String) =>
makeConverter(name, udt.sqlType, nullable, options)
case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}")
}
private def nullSafeDatum(
datum: String,
name: String,
castType: DataType,
nullable: Boolean = true,
options: CSVOptions = CSVOptions()): Any = {
// datum can be null if the number of fields found is less than the length of the schema
nullable: Boolean,
options: CSVOptions)(converter: ValueConverter): Any = {
if (datum == options.nullValue || datum == null) {
if (!nullable) {
throw new RuntimeException(s"null value found but field $name is not nullable.")
}
null
} else {
castType match {
case _: ByteType => datum.toByte
case _: ShortType => datum.toShort
case _: IntegerType => datum.toInt
case _: LongType => datum.toLong
case _: FloatType =>
datum match {
case options.nanValue => Float.NaN
case options.negativeInf => Float.NegativeInfinity
case options.positiveInf => Float.PositiveInfinity
case _ =>
Try(datum.toFloat)
.getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue())
}
case _: DoubleType =>
datum match {
case options.nanValue => Double.NaN
case options.negativeInf => Double.NegativeInfinity
case options.positiveInf => Double.PositiveInfinity
case _ =>
Try(datum.toDouble)
.getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue())
}
case _: BooleanType => datum.toBoolean
case dt: DecimalType =>
val value = new BigDecimal(datum.replaceAll(",", ""))
Decimal(value, dt.precision, dt.scale)
case _: TimestampType =>
// This one will lose microseconds parts.
// See https://issues.apache.org/jira/browse/SPARK-10681.
Try(options.timestampFormat.parse(datum).getTime * 1000L)
.getOrElse {
// If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
// compatibility.
DateTimeUtils.stringToTime(datum).getTime * 1000L
}
case _: DateType =>
// This one will lose microseconds parts.
// See https://issues.apache.org/jira/browse/SPARK-10681.x
Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime))
.getOrElse {
// If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
// compatibility.
DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
}
case _: StringType => UTF8String.fromString(datum)
case udt: UserDefinedType[_] => castTo(datum, name, udt.sqlType, nullable, options)
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
}
converter.apply(datum)
}
}
@@ -69,26 +69,26 @@ object CSVRelation extends Logging {
schema: StructType,
requiredColumns: Array[String],
params: CSVOptions): (Array[String], Int) => Option[InternalRow] = {
val schemaFields = schema.fields
val requiredFields = StructType(requiredColumns.map(schema(_))).fields
val safeRequiredFields = if (params.dropMalformed) {
// If `dropMalformed` is enabled, then it needs to parse all the values
// so that we can decide which row is malformed.
requiredFields ++ schemaFields.filterNot(requiredFields.contains(_))
requiredFields ++ schema.filterNot(requiredFields.contains(_))
} else {
requiredFields
}
val safeRequiredIndices = new Array[Int](safeRequiredFields.length)
schemaFields.zipWithIndex.filter {
case (field, _) => safeRequiredFields.contains(field)
}.foreach {
case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
schema.zipWithIndex.filter { case (field, _) =>
safeRequiredFields.contains(field)
}.foreach { case (field, index) =>
safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
}
val requiredSize = requiredFields.length
val row = new GenericInternalRow(requiredSize)
val converters = CSVTypeCast.makeConverters(schema, params)
(tokens: Array[String], numMalformedRows) => {
if (params.dropMalformed && schemaFields.length != tokens.length) {
if (params.dropMalformed && schema.length != tokens.length) {
if (numMalformedRows < params.maxMalformedLogPerPartition) {
logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
}
@@ -98,14 +98,14 @@ object CSVRelation extends Logging {
"found on this partition. Malformed records from now on will not be logged.")
}
None
} else if (params.failFast && schemaFields.length != tokens.length) {
} else if (params.failFast && schema.length != tokens.length) {
throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
s"${tokens.mkString(params.delimiter.toString)}")
} else {
val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.length) {
tokens ++ new Array[String](schemaFields.length - tokens.length)
} else if (params.permissive && schemaFields.length < tokens.length) {
tokens.take(schemaFields.length)
val indexSafeTokens = if (params.permissive && schema.length > tokens.length) {
tokens ++ new Array[String](schema.length - tokens.length)
} else if (params.permissive && schema.length < tokens.length) {
tokens.take(schema.length)
} else {
tokens
}
@@ -114,20 +114,14 @@ object CSVRelation extends Logging {
var subIndex: Int = 0
while (subIndex < safeRequiredIndices.length) {
index = safeRequiredIndices(subIndex)
val field = schemaFields(index)
// It anyway needs to try to parse since it decides if this row is malformed
// or not after trying to cast in `DROPMALFORMED` mode even if the casted
// value is not stored in the row.
val value = CSVTypeCast.castTo(
indexSafeTokens(index),
field.name,
field.dataType,
field.nullable,
params)
val value = converters(index).apply(indexSafeTokens(index))
if (subIndex < requiredSize) {
row(subIndex) = value
}
subIndex = subIndex + 1
subIndex += 1
}
Some(row)
} catch {
@@ -36,7 +36,7 @@ class CSVTypeCastSuite extends SparkFunSuite {
stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) =>
val decimalValue = new BigDecimal(decimalVal.toString)
assert(CSVTypeCast.castTo(strVal, "_1", decimalType) ===
assert(CSVTypeCast.makeConverter("_1", decimalType).apply(strVal) ===
Decimal(decimalValue, decimalType.precision, decimalType.scale))
}
}
@@ -66,92 +66,81 @@ class CSVTypeCastSuite extends SparkFunSuite {
}
test("Nullable types are handled") {
assertNull(
CSVTypeCast.castTo("-", "_1", ByteType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", "_1", ShortType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", "_1", IntegerType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", "_1", LongType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", "_1", FloatType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", "_1", DoubleType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", "_1", BooleanType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", "_1", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", "_1", TimestampType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", "_1", DateType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", "_1", StringType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo(null, "_1", IntegerType, nullable = true, CSVOptions("nullValue", "-")))
// casting a null to not nullable field should throw an exception.
var message = intercept[RuntimeException] {
CSVTypeCast.castTo(null, "_1", IntegerType, nullable = false, CSVOptions("nullValue", "-"))
}.getMessage
assert(message.contains("null value found but field _1 is not nullable."))
message = intercept[RuntimeException] {
CSVTypeCast.castTo("-", "_1", StringType, nullable = false, CSVOptions("nullValue", "-"))
}.getMessage
assert(message.contains("null value found but field _1 is not nullable."))
}
test("String type should also respect `nullValue`") {
assertNull(
CSVTypeCast.castTo("", "_1", StringType, nullable = true, CSVOptions()))
val types = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
BooleanType, DecimalType.DoubleDecimal, TimestampType, DateType, StringType)
// Nullable field with nullValue option.
types.foreach { t =>
// Tests that a custom nullValue.
val converter =
CSVTypeCast.makeConverter("_1", t, nullable = true, CSVOptions("nullValue", "-"))
assertNull(converter.apply("-"))
assertNull(converter.apply(null))
// Tests that the default nullValue is empty string.
assertNull(CSVTypeCast.makeConverter("_1", t, nullable = true).apply(""))
}
assert(
CSVTypeCast.castTo("", "_1", StringType, nullable = true, CSVOptions("nullValue", "null")) ==
UTF8String.fromString(""))
assert(
CSVTypeCast.castTo("", "_1", StringType, nullable = false, CSVOptions("nullValue", "null")) ==
UTF8String.fromString(""))
// Not nullable field with nullValue option.
types.foreach { t =>
// Casts a null to not nullable field should throw an exception.
val converter =
CSVTypeCast.makeConverter("_1", t, nullable = false, CSVOptions("nullValue", "-"))
var message = intercept[RuntimeException] {
converter.apply("-")
}.getMessage
assert(message.contains("null value found but field _1 is not nullable."))
message = intercept[RuntimeException] {
converter.apply(null)
}.getMessage
assert(message.contains("null value found but field _1 is not nullable."))
}
assertNull(
CSVTypeCast.castTo(null, "_1", StringType, nullable = true, CSVOptions("nullValue", "null")))
// If nullValue is different with empty string, then, empty string should not be casted into
// null.
Seq(true, false).foreach { b =>
val converter =
CSVTypeCast.makeConverter("_1", StringType, nullable = b, CSVOptions("nullValue", "null"))
assert(converter.apply("") == UTF8String.fromString(""))
}
}
test("Throws exception for empty string with non null type") {
val exception = intercept[RuntimeException]{
CSVTypeCast.castTo("", "_1", IntegerType, nullable = false, CSVOptions())
CSVTypeCast.makeConverter("_1", IntegerType, nullable = false, CSVOptions()).apply("")
}
assert(exception.getMessage.contains("null value found but field _1 is not nullable."))
}
test("Types are cast correctly") {
assert(CSVTypeCast.castTo("10", "_1", ByteType) == 10)
assert(CSVTypeCast.castTo("10", "_1", ShortType) == 10)
assert(CSVTypeCast.castTo("10", "_1", IntegerType) == 10)
assert(CSVTypeCast.castTo("10", "_1", LongType) == 10)
assert(CSVTypeCast.castTo("1.00", "_1", FloatType) == 1.0)
assert(CSVTypeCast.castTo("1.00", "_1", DoubleType) == 1.0)
assert(CSVTypeCast.castTo("true", "_1", BooleanType) == true)
assert(CSVTypeCast.makeConverter("_1", ByteType).apply("10") == 10)
assert(CSVTypeCast.makeConverter("_1", ShortType).apply("10") == 10)
assert(CSVTypeCast.makeConverter("_1", IntegerType).apply("10") == 10)
assert(CSVTypeCast.makeConverter("_1", LongType).apply("10") == 10)
assert(CSVTypeCast.makeConverter("_1", FloatType).apply("1.00") == 1.0)
assert(CSVTypeCast.makeConverter("_1", DoubleType).apply("1.00") == 1.0)
assert(CSVTypeCast.makeConverter("_1", BooleanType).apply("true") == true)
val timestampsOptions = CSVOptions("timestampFormat", "dd/MM/yyyy hh:mm")
val customTimestamp = "31/01/2015 00:00"
val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime
val castedTimestamp =
CSVTypeCast.castTo(customTimestamp, "_1", TimestampType, nullable = true, timestampsOptions)
CSVTypeCast.makeConverter("_1", TimestampType, nullable = true, timestampsOptions)
.apply(customTimestamp)
assert(castedTimestamp == expectedTime * 1000L)
val customDate = "31/01/2015"
val dateOptions = CSVOptions("dateFormat", "dd/MM/yyyy")
val expectedDate = dateOptions.dateFormat.parse(customDate).getTime
val castedDate =
CSVTypeCast.castTo(customTimestamp, "_1", DateType, nullable = true, dateOptions)
CSVTypeCast.makeConverter("_1", DateType, nullable = true, dateOptions)
.apply(customTimestamp)
assert(castedDate == DateTimeUtils.millisToDays(expectedDate))
val timestamp = "2015-01-01 00:00:00"
assert(CSVTypeCast.castTo(timestamp, "_1", TimestampType) ==
assert(CSVTypeCast.makeConverter("_1", TimestampType).apply(timestamp) ==
DateTimeUtils.stringToTime(timestamp).getTime * 1000L)
assert(CSVTypeCast.castTo("2015-01-01", "_1", DateType) ==
assert(CSVTypeCast.makeConverter("_1", DateType).apply("2015-01-01") ==
DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime))
}
@@ -159,16 +148,18 @@ class CSVTypeCastSuite extends SparkFunSuite {
val originalLocale = Locale.getDefault
try {
Locale.setDefault(new Locale("fr", "FR"))
assert(CSVTypeCast.castTo("1,00", "_1", FloatType) == 100.0) // Would parse as 1.0 in fr-FR
assert(CSVTypeCast.castTo("1,00", "_1", DoubleType) == 100.0)
// Would parse as 1.0 in fr-FR
assert(CSVTypeCast.makeConverter("_1", FloatType).apply("1,00") == 100.0)
assert(CSVTypeCast.makeConverter("_1", DoubleType).apply("1,00") == 100.0)
} finally {
Locale.setDefault(originalLocale)
}
}
test("Float NaN values are parsed correctly") {
val floatVal: Float = CSVTypeCast.castTo(
"nn", "_1", FloatType, nullable = true, CSVOptions("nanValue", "nn")).asInstanceOf[Float]
val floatVal: Float = CSVTypeCast.makeConverter(
"_1", FloatType, nullable = true, CSVOptions("nanValue", "nn")
).apply("nn").asInstanceOf[Float]
// Java implements the IEEE-754 floating point standard which guarantees that any comparison
// against NaN will return false (except != which returns true)
@@ -176,34 +167,37 @@ class CSVTypeCastSuite extends SparkFunSuite {
}
test("Double NaN values are parsed correctly") {
val doubleVal: Double = CSVTypeCast.castTo(
"-", "_1", DoubleType, nullable = true, CSVOptions("nanValue", "-")).asInstanceOf[Double]
val doubleVal: Double = CSVTypeCast.makeConverter(
"_1", DoubleType, nullable = true, CSVOptions("nanValue", "-")
).apply("-").asInstanceOf[Double]
assert(doubleVal.isNaN)
}
test("Float infinite values can be parsed") {
val floatVal1 = CSVTypeCast.castTo(
"max", "_1", FloatType, nullable = true, CSVOptions("negativeInf", "max")).asInstanceOf[Float]
val floatVal1 = CSVTypeCast.makeConverter(
"_1", FloatType, nullable = true, CSVOptions("negativeInf", "max")
).apply("max").asInstanceOf[Float]
assert(floatVal1 == Float.NegativeInfinity)
val floatVal2 = CSVTypeCast.castTo(
"max", "_1", FloatType, nullable = true, CSVOptions("positiveInf", "max")).asInstanceOf[Float]
val floatVal2 = CSVTypeCast.makeConverter(
"_1", FloatType, nullable = true, CSVOptions("positiveInf", "max")
).apply("max").asInstanceOf[Float]
assert(floatVal2 == Float.PositiveInfinity)
}
test("Double infinite values can be parsed") {
val doubleVal1 = CSVTypeCast.castTo(
"max", "_1", DoubleType, nullable = true, CSVOptions("negativeInf", "max")
).asInstanceOf[Double]
val doubleVal1 = CSVTypeCast.makeConverter(
"_1", DoubleType, nullable = true, CSVOptions("negativeInf", "max")
).apply("max").asInstanceOf[Double]
assert(doubleVal1 == Double.NegativeInfinity)
val doubleVal2 = CSVTypeCast.castTo(
"max", "_1", DoubleType, nullable = true, CSVOptions("positiveInf", "max")
).asInstanceOf[Double]
val doubleVal2 = CSVTypeCast.makeConverter(
"_1", DoubleType, nullable = true, CSVOptions("positiveInf", "max")
).apply("max").asInstanceOf[Double]
assert(doubleVal2 == Double.PositiveInfinity)
}