diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index cfd66af18892b80bbee737d706569f5782cfed23..05c8d8ee15f668fce2d9d901d0811da319ab8f20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.math.BigDecimal -import java.text.{NumberFormat, SimpleDateFormat} +import java.text.NumberFormat import java.util.Locale import scala.util.control.Exception._ @@ -85,6 +85,7 @@ private[csv] object CSVInferSchema { case NullType => tryParseInteger(field, options) case IntegerType => tryParseInteger(field, options) case LongType => tryParseLong(field, options) + case _: DecimalType => tryParseDecimal(field, options) case DoubleType => tryParseDouble(field, options) case TimestampType => tryParseTimestamp(field, options) case BooleanType => tryParseBoolean(field, options) @@ -107,10 +108,28 @@ private[csv] object CSVInferSchema { if ((allCatch opt field.toLong).isDefined) { LongType } else { - tryParseDouble(field, options) + tryParseDecimal(field, options) } } + private def tryParseDecimal(field: String, options: CSVOptions): DataType = { + val decimalTry = allCatch opt { + // `BigDecimal` conversion can fail when the `field` is not a form of number. + val bigDecimal = new BigDecimal(field) + // Because many other formats do not support decimal, it reduces the cases for + // decimals by disallowing values having scale (eg. `1.1`). + if (bigDecimal.scale <= 0) { + // `DecimalType` conversion can fail when + // 1. The precision is bigger than 38. + // 2. scale is bigger than precision. + DecimalType(bigDecimal.precision, bigDecimal.scale) + } else { + tryParseDouble(field, options) + } + } + decimalTry.getOrElse(tryParseDouble(field, options)) + } + private def tryParseDouble(field: String, options: CSVOptions): DataType = { if ((allCatch opt field.toDouble).isDefined) { DoubleType @@ -170,6 +189,33 @@ private[csv] object CSVInferSchema { val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) Some(numericPrecedence(index)) + // These two cases below deal with when `DecimalType` is larger than `IntegralType`. + case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) => + Some(t2) + case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) => + Some(t1) + + // These two cases below deal with when `IntegralType` is larger than `DecimalType`. + case (t1: IntegralType, t2: DecimalType) => + findTightestCommonType(DecimalType.forType(t1), t2) + case (t1: DecimalType, t2: IntegralType) => + findTightestCommonType(t1, DecimalType.forType(t2)) + + // Double support larger range than fixed decimal, DecimalType.Maximum should be enough + // in most case, also have better precision. + case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => + Some(DoubleType) + + case (t1: DecimalType, t2: DecimalType) => + val scale = math.max(t1.scale, t2.scale) + val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) + if (range + scale > 38) { + // DecimalType can't support precision > 38 + Some(DoubleType) + } else { + Some(DecimalType(range + scale, scale)) + } + case _ => None } } diff --git a/sql/core/src/test/resources/decimal.csv b/sql/core/src/test/resources/decimal.csv new file mode 100644 index 0000000000000000000000000000000000000000..870f6aaf1bb4c3f0e3053348c696111dc767c9a3 --- /dev/null +++ b/sql/core/src/test/resources/decimal.csv @@ -0,0 +1,7 @@ +~ decimal field has integer, integer and decimal values. The last value cannot fit to a long +~ long field has integer, long and integer values. +~ double field has double, double and decimal values. +decimal,long,double +1,1,0.1 +1,9223372036854775807,1.0 +92233720368547758070,1,92233720368547758070 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index daf85be56f3d25e4b437385ecae58430de2f4d97..dbe3af49c90c36ec91956a97131951428ba6c186 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.datasources.csv -import java.text.SimpleDateFormat - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -35,6 +33,11 @@ class CSVInferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType) assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType) assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType) + + val textValueOne = Long.MaxValue.toString + "0" + val decimalValueOne = new java.math.BigDecimal(textValueOne) + val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) + assert(CSVInferSchema.inferField(NullType, textValueOne, options) == expectedTypeOne) } test("String fields types are inferred correctly from other types") { @@ -49,6 +52,11 @@ class CSVInferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType) assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType) assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType) + + val textValueOne = Long.MaxValue.toString + "0" + val decimalValueOne = new java.math.BigDecimal(textValueOne) + val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) + assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne) } test("Timestamp field types are inferred correctly via custom data format") { @@ -94,6 +102,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType) + assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1)) } test("Merging Nulltypes should yield Nulltype.") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index ae91e0f606eccf46756e913e37419671bf35e062..27d6dc9197d27fcd5f9dd255b4b6278d2b4e0ed7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -43,6 +43,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val commentsFile = "comments.csv" private val disableCommentsFile = "disable_comments.csv" private val boolFile = "bool.csv" + private val decimalFile = "decimal.csv" private val simpleSparseFile = "simple_sparse.csv" private val numbersFile = "numbers.csv" private val datesFile = "dates.csv" @@ -133,6 +134,20 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(result.schema === expectedSchema) } + test("test inferring decimals") { + val result = sqlContext.read + .format("csv") + .option("comment", "~") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(decimalFile)) + val expectedSchema = StructType(List( + StructField("decimal", DecimalType(20, 0), nullable = true), + StructField("long", LongType, nullable = true), + StructField("double", DoubleType, nullable = true))) + assert(result.schema === expectedSchema) + } + test("test with alternative delimiter and quote") { val cars = spark.read .format("csv")