Commit 51841d77 authored by hyukjinkwon, committed by Davies Liu

[SPARK-13866] [SQL] Handle decimal type in CSV inference at CSV data source.

## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-13866

This PR adds the support to infer `DecimalType`.
Here are the rules for inferring among `IntegerType`, `LongType`, `DecimalType`, and `DoubleType`.

#### Inferring Types

1. `IntegerType` and then `LongType` are tried first.

  ```scala
  Int.MaxValue => IntegerType
  Long.MaxValue => LongType
  ```

2. If it fails, try `DecimalType`.

  ```scala
  (Long.MaxValue + 1) => DecimalType(20, 0)
  ```
  Note that values with fractional digits (scale greater than 0) are not inferred as `DecimalType`.

3. If that also fails, try `DoubleType` (a standalone sketch of the whole cascade follows this list).
  ```scala
  0.1 => DoubleType // Not inferred as `DecimalType` because its scale is 1.
  ```
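
Taken together, the three rules form a fallback cascade. Below is a minimal standalone sketch of that cascade; `inferNumeric` and its string results are illustrative stand-ins, not Spark's actual API:

```scala
import java.math.BigDecimal
import scala.util.Try

// Illustrative stand-in for CSVInferSchema's per-field cascade.
def inferNumeric(field: String): String = {
  if (Try(field.toInt).isSuccess) {
    "IntegerType"
  } else if (Try(field.toLong).isSuccess) {
    "LongType"
  } else {
    Try(new BigDecimal(field)).toOption match {
      // Only values without fractional digits (scale <= 0) become decimals.
      case Some(d) if d.scale <= 0 => s"DecimalType(${d.precision}, ${d.scale})"
      case _ if Try(field.toDouble).isSuccess => "DoubleType"
      case _ => "StringType"
    }
  }
}

assert(inferNumeric("12") == "IntegerType")
assert(inferNumeric("9223372036854775807") == "LongType")            // Long.MaxValue
assert(inferNumeric("92233720368547758070") == "DecimalType(20, 0)") // Long.MaxValue * 10
assert(inferNumeric("0.1") == "DoubleType")                          // scale 1, not decimal
```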

#### Compatible Types (Merging Types)

Merging types follows the same rules as the JSON data source: if `DecimalType` cannot represent the merged values (the required precision would exceed 38 digits), the result widens to `DoubleType`. For example, merging a `DecimalType` with `DoubleType` yields `DoubleType`.

## How was this patch tested?

Unit tests were added, and `./dev/run_tests` was run for style checks.

Author: hyukjinkwon <gurwls223@gmail.com>
Author: Hyukjin Kwon <gurwls223@gmail.com>

Closes #11724 from HyukjinKwon/SPARK-13866.
parent eda2800d
`CSVInferSchema.scala`:

```diff
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.datasources.csv

 import java.math.BigDecimal
-import java.text.{NumberFormat, SimpleDateFormat}
+import java.text.NumberFormat
 import java.util.Locale

 import scala.util.control.Exception._
```
```diff
@@ -85,6 +85,7 @@ private[csv] object CSVInferSchema {
       case NullType => tryParseInteger(field, options)
       case IntegerType => tryParseInteger(field, options)
       case LongType => tryParseLong(field, options)
+      case _: DecimalType => tryParseDecimal(field, options)
       case DoubleType => tryParseDouble(field, options)
       case TimestampType => tryParseTimestamp(field, options)
       case BooleanType => tryParseBoolean(field, options)
```
```diff
@@ -107,10 +108,28 @@ private[csv] object CSVInferSchema {
     if ((allCatch opt field.toLong).isDefined) {
       LongType
     } else {
-      tryParseDouble(field, options)
+      tryParseDecimal(field, options)
     }
   }

+  private def tryParseDecimal(field: String, options: CSVOptions): DataType = {
+    val decimalTry = allCatch opt {
+      // `BigDecimal` conversion can fail when the `field` is not a form of number.
+      val bigDecimal = new BigDecimal(field)
+      // Because many other formats do not support decimal, it reduces the cases for
+      // decimals by disallowing values having scale (eg. `1.1`).
+      if (bigDecimal.scale <= 0) {
+        // `DecimalType` conversion can fail when
+        //   1. The precision is bigger than 38.
+        //   2. scale is bigger than precision.
+        DecimalType(bigDecimal.precision, bigDecimal.scale)
+      } else {
+        tryParseDouble(field, options)
+      }
+    }
+    decimalTry.getOrElse(tryParseDouble(field, options))
+  }
+
   private def tryParseDouble(field: String, options: CSVOptions): DataType = {
     if ((allCatch opt field.toDouble).isDefined) {
       DoubleType
```
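
For reference, the `scale <= 0` guard in `tryParseDecimal` relies on how `java.math.BigDecimal` reports precision (total significant digits) and scale (digits right of the decimal point). A quick standalone check, not part of the patch:

```scala
import java.math.BigDecimal

val wide = new BigDecimal("92233720368547758070")
assert(wide.precision == 20 && wide.scale == 0) // eligible: DecimalType(20, 0)

val fractional = new BigDecimal("0.1")
assert(fractional.scale == 1) // positive scale: falls through to DoubleType
```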
```diff
@@ -170,6 +189,33 @@ private[csv] object CSVInferSchema {
       val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
       Some(numericPrecedence(index))

+    // These two cases below deal with when `DecimalType` is larger than `IntegralType`.
+    case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) =>
+      Some(t2)
+    case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) =>
+      Some(t1)
+
+    // These two cases below deal with when `IntegralType` is larger than `DecimalType`.
+    case (t1: IntegralType, t2: DecimalType) =>
+      findTightestCommonType(DecimalType.forType(t1), t2)
+    case (t1: DecimalType, t2: IntegralType) =>
+      findTightestCommonType(t1, DecimalType.forType(t2))
+
+    // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
+    // in most case, also have better precision.
+    case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
+      Some(DoubleType)
+
+    case (t1: DecimalType, t2: DecimalType) =>
+      val scale = math.max(t1.scale, t2.scale)
+      val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+      if (range + scale > 38) {
+        // DecimalType can't support precision > 38
+        Some(DoubleType)
+      } else {
+        Some(DecimalType(range + scale, scale))
+      }
+
     case _ => None
   }
 }
```
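
The decimal-merging arithmetic in the last case above can be checked in isolation. A minimal re-implementation sketch (the `mergeDecimals` helper is illustrative, not Spark API):

```scala
// Merge two decimal types (precision, scale); None signals that the result
// would exceed the 38-digit precision cap, i.e. the DoubleType fallback.
def mergeDecimals(p1: Int, s1: Int, p2: Int, s2: Int): Option[(Int, Int)] = {
  val scale = math.max(s1, s2)
  val range = math.max(p1 - s1, p2 - s2) // digits left of the decimal point
  if (range + scale > 38) None else Some((range + scale, scale))
}

// DecimalType(20, 0) merged with DecimalType(5, 3) keeps 20 integer digits
// and 3 fractional digits: DecimalType(23, 3).
assert(mergeDecimals(20, 0, 5, 3) == Some((23, 3)))

// Two wide decimals overflow the cap, so inference falls back to DoubleType.
assert(mergeDecimals(38, 0, 20, 10) == None)
```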
`decimal.csv` (new test fixture):

```diff
+~ decimal field has integer, integer and decimal values. The last value cannot fit to a long
+~ long field has integer, long and integer values.
+~ double field has double, double and decimal values.
+decimal,long,double
+1,1,0.1
+1,9223372036854775807,1.0
+92233720368547758070,1,92233720368547758070
```
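
Given this fixture, the widest value in the `decimal` column, `92233720368547758070`, has 20 digits and no fractional part, so the column is inferred as `DecimalType(20, 0)`; the `long` column never exceeds `Long.MaxValue`, so it stays `LongType`; and the `double` column mixes fractional values with a 20-digit integral value, and merging `DoubleType` with a `DecimalType` yields `DoubleType`. The `CSVSuite` test below asserts exactly this schema.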
`CSVInferSchemaSuite.scala`:

```diff
@@ -17,8 +17,6 @@
 package org.apache.spark.sql.execution.datasources.csv

-import java.text.SimpleDateFormat
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.types._
```
```diff
@@ -35,6 +33,11 @@ class CSVInferSchemaSuite extends SparkFunSuite {
     assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType)
     assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType)
     assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType)
+
+    val textValueOne = Long.MaxValue.toString + "0"
+    val decimalValueOne = new java.math.BigDecimal(textValueOne)
+    val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale)
+    assert(CSVInferSchema.inferField(NullType, textValueOne, options) == expectedTypeOne)
   }

   test("String fields types are inferred correctly from other types") {
```
```diff
@@ -49,6 +52,11 @@ class CSVInferSchemaSuite extends SparkFunSuite {
     assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType)
     assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType)
     assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType)
+
+    val textValueOne = Long.MaxValue.toString + "0"
+    val decimalValueOne = new java.math.BigDecimal(textValueOne)
+    val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale)
+    assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne)
   }

   test("Timestamp field types are inferred correctly via custom data format") {
```
```diff
@@ -94,6 +102,7 @@ class CSVInferSchemaSuite extends SparkFunSuite {
     assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType)
     assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType)
     assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType)
+    assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1))
   }

   test("Merging Nulltypes should yield Nulltype.") {
```
`CSVSuite.scala`:

```diff
@@ -43,6 +43,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   private val commentsFile = "comments.csv"
   private val disableCommentsFile = "disable_comments.csv"
   private val boolFile = "bool.csv"
+  private val decimalFile = "decimal.csv"
   private val simpleSparseFile = "simple_sparse.csv"
   private val numbersFile = "numbers.csv"
   private val datesFile = "dates.csv"
@@ -133,6 +134,20 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     assert(result.schema === expectedSchema)
   }

+  test("test inferring decimals") {
+    val result = sqlContext.read
+      .format("csv")
+      .option("comment", "~")
+      .option("header", "true")
+      .option("inferSchema", "true")
+      .load(testFile(decimalFile))
+    val expectedSchema = StructType(List(
+      StructField("decimal", DecimalType(20, 0), nullable = true),
+      StructField("long", LongType, nullable = true),
+      StructField("double", DoubleType, nullable = true)))
+    assert(result.schema === expectedSchema)
+  }
+
   test("test with alternative delimiter and quote") {
     val cars = spark.read
       .format("csv")
```
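
For completeness, a user-level sketch of the new behavior, assuming a `sqlContext` as in the test above; the path is illustrative:

```scala
// With schema inference on, a 20-digit integral column is now read as
// decimal(20,0) instead of silently widening to double.
val df = sqlContext.read
  .format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .load("/path/to/decimal.csv") // illustrative path

df.printSchema() // expected: decimal -> decimal(20,0), long -> long, double -> double
```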