diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index c63aae9d83855ea7c7f575ea5ceefc0ccd71505e..88c608add140f149bcb56edcf5077e4a369a38f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -215,84 +215,121 @@ private[csv] object CSVInferSchema {
 }
 
 private[csv] object CSVTypeCast {
+  // A `ValueConverter` is responsible for converting the given value to a desired type.
+  private type ValueConverter = String => Any
 
   /**
-   * Casts given string datum to specified type.
-   * Currently we do not support complex types (ArrayType, MapType, StructType).
+   * Create converters which cast each given string datum to the corresponding type in the given schema.
+   * Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`).
    *
-   * For string types, this is simply the datum. For other types.
+   * For string types, this is simply the datum.
+   * For other types, the datum is converted into a value of the corresponding type.
    * For other nullable types, returns null if it is null or equals to the value specified
    * in `nullValue` option.
    *
-   * @param datum string value
-   * @param name field name in schema.
-   * @param castType data type to cast `datum` into.
-   * @param nullable nullability for the field.
+   * @param schema schema that contains data types to cast the given values into.
    * @param options CSV options.
    */
-  def castTo(
+  def makeConverters(
+      schema: StructType,
+      options: CSVOptions = CSVOptions()): Array[ValueConverter] = {
+    schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
+  }
+
+  /**
+   * Create a converter which converts the string value to a value according to a desired type.
+   */
+  def makeConverter(
+      name: String,
+      dataType: DataType,
+      nullable: Boolean = true,
+      options: CSVOptions = CSVOptions()): ValueConverter = dataType match {
+    case _: ByteType => (d: String) =>
+      nullSafeDatum(d, name, nullable, options)(_.toByte)
+
+    case _: ShortType => (d: String) =>
+      nullSafeDatum(d, name, nullable, options)(_.toShort)
+
+    case _: IntegerType => (d: String) =>
+      nullSafeDatum(d, name, nullable, options)(_.toInt)
+
+    case _: LongType => (d: String) =>
+      nullSafeDatum(d, name, nullable, options)(_.toLong)
+
+    case _: FloatType => (d: String) =>
+      nullSafeDatum(d, name, nullable, options) {
+        case options.nanValue => Float.NaN
+        case options.negativeInf => Float.NegativeInfinity
+        case options.positiveInf => Float.PositiveInfinity
+        case datum =>
+          Try(datum.toFloat)
+            .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue())
+      }
+
+    case _: DoubleType => (d: String) =>
+      nullSafeDatum(d, name, nullable, options) {
+        case options.nanValue => Double.NaN
+        case options.negativeInf => Double.NegativeInfinity
+        case options.positiveInf => Double.PositiveInfinity
+        case datum =>
+          Try(datum.toDouble)
+            .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue())
+      }
+
+    case _: BooleanType => (d: String) =>
+      nullSafeDatum(d, name, nullable, options)(_.toBoolean)
+
+    case dt: DecimalType => (d: String) =>
+      nullSafeDatum(d, name, nullable, options) { datum =>
+        val value = new BigDecimal(datum.replaceAll(",", ""))
+        Decimal(value, dt.precision, dt.scale)
+      }
+
+    case _: TimestampType => (d: String) =>
+      nullSafeDatum(d, name, nullable, options) { datum =>
+        // This one will lose microseconds parts.
+        // See https://issues.apache.org/jira/browse/SPARK-10681.
+        Try(options.timestampFormat.parse(datum).getTime * 1000L)
+          .getOrElse {
+            // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
+            // compatibility.
+            DateTimeUtils.stringToTime(datum).getTime * 1000L
+          }
+      }
+
+    case _: DateType => (d: String) =>
+      nullSafeDatum(d, name, nullable, options) { datum =>
+        // This one will lose microseconds parts.
+        // See https://issues.apache.org/jira/browse/SPARK-10681.
+        Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime))
+          .getOrElse {
+            // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
+            // compatibility.
+            DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
+          }
+      }
+
+    case _: StringType => (d: String) =>
+      nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_))
+
+    case udt: UserDefinedType[_] =>
+      makeConverter(name, udt.sqlType, nullable, options)
+
+    case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}")
+  }
+
+  private def nullSafeDatum(
       datum: String,
       name: String,
-      castType: DataType,
-      nullable: Boolean = true,
-      options: CSVOptions = CSVOptions()): Any = {
-
-    // datum can be null if the number of fields found is less than the length of the schema
+      nullable: Boolean,
+      options: CSVOptions)(converter: ValueConverter): Any = {
     if (datum == options.nullValue || datum == null) {
       if (!nullable) {
         throw new RuntimeException(s"null value found but field $name is not nullable.")
       }
       null
     } else {
-      castType match {
-        case _: ByteType => datum.toByte
-        case _: ShortType => datum.toShort
-        case _: IntegerType => datum.toInt
-        case _: LongType => datum.toLong
-        case _: FloatType =>
-          datum match {
-            case options.nanValue => Float.NaN
-            case options.negativeInf => Float.NegativeInfinity
-            case options.positiveInf => Float.PositiveInfinity
-            case _ =>
-              Try(datum.toFloat)
-                .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue())
-          }
-        case _: DoubleType =>
-          datum match {
-            case options.nanValue => Double.NaN
-            case options.negativeInf => Double.NegativeInfinity
-            case options.positiveInf => Double.PositiveInfinity
-            case _ =>
-              Try(datum.toDouble)
-                .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue())
-          }
-        case _: BooleanType => datum.toBoolean
-        case dt: DecimalType =>
-          val value = new BigDecimal(datum.replaceAll(",", ""))
-          Decimal(value, dt.precision, dt.scale)
-        case _: TimestampType =>
-          // This one will lose microseconds parts.
-          // See https://issues.apache.org/jira/browse/SPARK-10681.
-          Try(options.timestampFormat.parse(datum).getTime * 1000L)
-            .getOrElse {
-              // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
-              // compatibility.
-              DateTimeUtils.stringToTime(datum).getTime * 1000L
-            }
-        case _: DateType =>
-          // This one will lose microseconds parts.
-          // See https://issues.apache.org/jira/browse/SPARK-10681.x
-          Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime))
-            .getOrElse {
-              // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
-              // compatibility.
-              DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
-            }
-        case _: StringType => UTF8String.fromString(datum)
-        case udt: UserDefinedType[_] => castTo(datum, name, udt.sqlType, nullable, options)
-        case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
-      }
+      converter.apply(datum)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index e4ce7a94be7df1874534fc6e02a3612314bdbe37..23c07eb630d3116646f45273144c125f10cafeee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -69,26 +69,26 @@ object CSVRelation extends Logging {
       schema: StructType,
       requiredColumns: Array[String],
       params: CSVOptions): (Array[String], Int) => Option[InternalRow] = {
-    val schemaFields = schema.fields
     val requiredFields = StructType(requiredColumns.map(schema(_))).fields
     val safeRequiredFields = if (params.dropMalformed) {
       // If `dropMalformed` is enabled, then it needs to parse all the values
       // so that we can decide which row is malformed.
-      requiredFields ++ schemaFields.filterNot(requiredFields.contains(_))
+      requiredFields ++ schema.filterNot(requiredFields.contains(_))
     } else {
       requiredFields
     }
     val safeRequiredIndices = new Array[Int](safeRequiredFields.length)
-    schemaFields.zipWithIndex.filter {
-      case (field, _) => safeRequiredFields.contains(field)
-    }.foreach {
-      case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
+    schema.zipWithIndex.filter { case (field, _) =>
+      safeRequiredFields.contains(field)
+    }.foreach { case (field, index) =>
+      safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
     }
     val requiredSize = requiredFields.length
     val row = new GenericInternalRow(requiredSize)
+    val converters = CSVTypeCast.makeConverters(schema, params)
     (tokens: Array[String], numMalformedRows) => {
-      if (params.dropMalformed && schemaFields.length != tokens.length) {
+      if (params.dropMalformed && schema.length != tokens.length) {
         if (numMalformedRows < params.maxMalformedLogPerPartition) {
           logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
         }
@@ -98,14 +98,14 @@ object CSVRelation extends Logging {
             "found on this partition. Malformed records from now on will not be logged.")
         }
         None
-      } else if (params.failFast && schemaFields.length != tokens.length) {
+      } else if (params.failFast && schema.length != tokens.length) {
         throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
           s"${tokens.mkString(params.delimiter.toString)}")
       } else {
-        val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.length) {
-          tokens ++ new Array[String](schemaFields.length - tokens.length)
-        } else if (params.permissive && schemaFields.length < tokens.length) {
-          tokens.take(schemaFields.length)
+        val indexSafeTokens = if (params.permissive && schema.length > tokens.length) {
+          tokens ++ new Array[String](schema.length - tokens.length)
+        } else if (params.permissive && schema.length < tokens.length) {
+          tokens.take(schema.length)
         } else {
           tokens
         }
@@ -114,20 +114,14 @@ object CSVRelation extends Logging {
           var subIndex: Int = 0
           while (subIndex < safeRequiredIndices.length) {
            index = safeRequiredIndices(subIndex)
-            val field = schemaFields(index)
            // It anyway needs to try to parse since it decides if this row is malformed
            // or not after trying to cast in `DROPMALFORMED` mode even if the casted
            // value is not stored in the row.
-            val value = CSVTypeCast.castTo(
-              indexSafeTokens(index),
-              field.name,
-              field.dataType,
-              field.nullable,
-              params)
+            val value = converters(index).apply(indexSafeTokens(index))
            if (subIndex < requiredSize) {
              row(subIndex) = value
            }
-            subIndex = subIndex + 1
+            subIndex += 1
          }
          Some(row)
        } catch {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
index 46333d12138fb61dae7b8bc6617be2036ff93106..ffd3d260bcb4036805c58305650cfd27a55c764b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
@@ -36,7 +36,7 @@ class CSVTypeCastSuite extends SparkFunSuite {
     stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) =>
       val decimalValue = new BigDecimal(decimalVal.toString)
-      assert(CSVTypeCast.castTo(strVal, "_1", decimalType) ===
+      assert(CSVTypeCast.makeConverter("_1", decimalType).apply(strVal) ===
         Decimal(decimalValue, decimalType.precision, decimalType.scale))
     }
   }
@@ -66,92 +66,81 @@ class CSVTypeCastSuite extends SparkFunSuite {
   }
 
   test("Nullable types are handled") {
-    assertNull(
-      CSVTypeCast.castTo("-", "_1", ByteType, nullable = true, CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", "_1", ShortType, nullable = true, CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", "_1", IntegerType, nullable = true, CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", "_1", LongType, nullable = true, CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", "_1", FloatType, nullable = true, CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", "_1", DoubleType, nullable = true, CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", "_1", BooleanType, nullable = true, CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", "_1", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", "_1", TimestampType, nullable = true, CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", "_1", DateType, nullable = true, CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", "_1", StringType, nullable = true, CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo(null, "_1", IntegerType, nullable = true, CSVOptions("nullValue", "-")))
-
-    // casting a null to not nullable field should throw an exception.
-    var message = intercept[RuntimeException] {
-      CSVTypeCast.castTo(null, "_1", IntegerType, nullable = false, CSVOptions("nullValue", "-"))
-    }.getMessage
-    assert(message.contains("null value found but field _1 is not nullable."))
-
-    message = intercept[RuntimeException] {
-      CSVTypeCast.castTo("-", "_1", StringType, nullable = false, CSVOptions("nullValue", "-"))
-    }.getMessage
-    assert(message.contains("null value found but field _1 is not nullable."))
-  }
-
-  test("String type should also respect `nullValue`") {
-    assertNull(
-      CSVTypeCast.castTo("", "_1", StringType, nullable = true, CSVOptions()))
+    val types = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
+      BooleanType, DecimalType.DoubleDecimal, TimestampType, DateType, StringType)
+
+    // Nullable field with nullValue option.
+    types.foreach { t =>
+      // Tests that a custom nullValue is respected.
+      val converter =
+        CSVTypeCast.makeConverter("_1", t, nullable = true, CSVOptions("nullValue", "-"))
+      assertNull(converter.apply("-"))
+      assertNull(converter.apply(null))
+
+      // Tests that the default nullValue is the empty string.
+      assertNull(CSVTypeCast.makeConverter("_1", t, nullable = true).apply(""))
+    }
 
-    assert(
-      CSVTypeCast.castTo("", "_1", StringType, nullable = true, CSVOptions("nullValue", "null")) ==
-        UTF8String.fromString(""))
-    assert(
-      CSVTypeCast.castTo("", "_1", StringType, nullable = false, CSVOptions("nullValue", "null")) ==
-        UTF8String.fromString(""))
+    // Not nullable field with nullValue option.
+    types.foreach { t =>
+      // Casting a null into a non-nullable field should throw an exception.
+      val converter =
+        CSVTypeCast.makeConverter("_1", t, nullable = false, CSVOptions("nullValue", "-"))
+      var message = intercept[RuntimeException] {
+        converter.apply("-")
+      }.getMessage
+      assert(message.contains("null value found but field _1 is not nullable."))
+      message = intercept[RuntimeException] {
+        converter.apply(null)
+      }.getMessage
+      assert(message.contains("null value found but field _1 is not nullable."))
+    }
 
-    assertNull(
-      CSVTypeCast.castTo(null, "_1", StringType, nullable = true, CSVOptions("nullValue", "null")))
+    // If nullValue is set to something other than the empty string, then the empty string
+    // should not be cast to null.
+    Seq(true, false).foreach { b =>
+      val converter =
+        CSVTypeCast.makeConverter("_1", StringType, nullable = b, CSVOptions("nullValue", "null"))
+      assert(converter.apply("") == UTF8String.fromString(""))
+    }
   }
 
   test("Throws exception for empty string with non null type") {
     val exception = intercept[RuntimeException]{
-      CSVTypeCast.castTo("", "_1", IntegerType, nullable = false, CSVOptions())
+      CSVTypeCast.makeConverter("_1", IntegerType, nullable = false, CSVOptions()).apply("")
     }
     assert(exception.getMessage.contains("null value found but field _1 is not nullable."))
   }
 
   test("Types are cast correctly") {
-    assert(CSVTypeCast.castTo("10", "_1", ByteType) == 10)
-    assert(CSVTypeCast.castTo("10", "_1", ShortType) == 10)
-    assert(CSVTypeCast.castTo("10", "_1", IntegerType) == 10)
-    assert(CSVTypeCast.castTo("10", "_1", LongType) == 10)
-    assert(CSVTypeCast.castTo("1.00", "_1", FloatType) == 1.0)
-    assert(CSVTypeCast.castTo("1.00", "_1", DoubleType) == 1.0)
-    assert(CSVTypeCast.castTo("true", "_1", BooleanType) == true)
+    assert(CSVTypeCast.makeConverter("_1", ByteType).apply("10") == 10)
+    assert(CSVTypeCast.makeConverter("_1", ShortType).apply("10") == 10)
+    assert(CSVTypeCast.makeConverter("_1", IntegerType).apply("10") == 10)
+    assert(CSVTypeCast.makeConverter("_1", LongType).apply("10") == 10)
+    assert(CSVTypeCast.makeConverter("_1", FloatType).apply("1.00") == 1.0)
+    assert(CSVTypeCast.makeConverter("_1", DoubleType).apply("1.00") == 1.0)
+    assert(CSVTypeCast.makeConverter("_1", BooleanType).apply("true") == true)
 
     val timestampsOptions = CSVOptions("timestampFormat", "dd/MM/yyyy hh:mm")
     val customTimestamp = "31/01/2015 00:00"
     val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime
     val castedTimestamp =
-      CSVTypeCast.castTo(customTimestamp, "_1", TimestampType, nullable = true, timestampsOptions)
+      CSVTypeCast.makeConverter("_1", TimestampType, nullable = true, timestampsOptions)
+        .apply(customTimestamp)
     assert(castedTimestamp == expectedTime * 1000L)
 
     val customDate = "31/01/2015"
     val dateOptions = CSVOptions("dateFormat", "dd/MM/yyyy")
     val expectedDate = dateOptions.dateFormat.parse(customDate).getTime
     val castedDate =
-      CSVTypeCast.castTo(customTimestamp, "_1", DateType, nullable = true, dateOptions)
+      CSVTypeCast.makeConverter("_1", DateType, nullable = true, dateOptions)
+        .apply(customTimestamp)
     assert(castedDate == DateTimeUtils.millisToDays(expectedDate))
 
     val timestamp = "2015-01-01 00:00:00"
-    assert(CSVTypeCast.castTo(timestamp, "_1", TimestampType) ==
+    assert(CSVTypeCast.makeConverter("_1", TimestampType).apply(timestamp) ==
      DateTimeUtils.stringToTime(timestamp).getTime * 1000L)
-    assert(CSVTypeCast.castTo("2015-01-01", "_1", DateType) ==
+    assert(CSVTypeCast.makeConverter("_1", DateType).apply("2015-01-01") ==
      DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime))
   }
 
@@ -159,16 +148,18 @@ class CSVTypeCastSuite extends SparkFunSuite {
     val originalLocale = Locale.getDefault
     try {
      Locale.setDefault(new Locale("fr", "FR"))
-      assert(CSVTypeCast.castTo("1,00", "_1", FloatType) == 100.0) // Would parse as 1.0 in fr-FR
-      assert(CSVTypeCast.castTo("1,00", "_1", DoubleType) == 100.0)
+      // Would parse as 1.0 in fr-FR
+      assert(CSVTypeCast.makeConverter("_1", FloatType).apply("1,00") == 100.0)
+      assert(CSVTypeCast.makeConverter("_1", DoubleType).apply("1,00") == 100.0)
     } finally {
      Locale.setDefault(originalLocale)
     }
   }
 
   test("Float NaN values are parsed correctly") {
-    val floatVal: Float = CSVTypeCast.castTo(
-      "nn", "_1", FloatType, nullable = true, CSVOptions("nanValue", "nn")).asInstanceOf[Float]
+    val floatVal: Float = CSVTypeCast.makeConverter(
+      "_1", FloatType, nullable = true, CSVOptions("nanValue", "nn")
+    ).apply("nn").asInstanceOf[Float]
 
     // Java implements the IEEE-754 floating point standard which guarantees that any comparison
     // against NaN will return false (except != which returns true)
@@ -176,34 +167,37 @@ class CSVTypeCastSuite extends SparkFunSuite {
   }
 
   test("Double NaN values are parsed correctly") {
-    val doubleVal: Double = CSVTypeCast.castTo(
-      "-", "_1", DoubleType, nullable = true, CSVOptions("nanValue", "-")).asInstanceOf[Double]
+    val doubleVal: Double = CSVTypeCast.makeConverter(
+      "_1", DoubleType, nullable = true, CSVOptions("nanValue", "-")
+    ).apply("-").asInstanceOf[Double]
     assert(doubleVal.isNaN)
   }
 
   test("Float infinite values can be parsed") {
-    val floatVal1 = CSVTypeCast.castTo(
-      "max", "_1", FloatType, nullable = true, CSVOptions("negativeInf", "max")).asInstanceOf[Float]
+    val floatVal1 = CSVTypeCast.makeConverter(
+      "_1", FloatType, nullable = true, CSVOptions("negativeInf", "max")
+    ).apply("max").asInstanceOf[Float]
     assert(floatVal1 == Float.NegativeInfinity)
 
-    val floatVal2 = CSVTypeCast.castTo(
-      "max", "_1", FloatType, nullable = true, CSVOptions("positiveInf", "max")).asInstanceOf[Float]
+    val floatVal2 = CSVTypeCast.makeConverter(
+      "_1", FloatType, nullable = true, CSVOptions("positiveInf", "max")
+    ).apply("max").asInstanceOf[Float]
     assert(floatVal2 == Float.PositiveInfinity)
   }
 
   test("Double infinite values can be parsed") {
-    val doubleVal1 = CSVTypeCast.castTo(
-      "max", "_1", DoubleType, nullable = true, CSVOptions("negativeInf", "max")
-    ).asInstanceOf[Double]
+    val doubleVal1 = CSVTypeCast.makeConverter(
+      "_1", DoubleType, nullable = true, CSVOptions("negativeInf", "max")
+    ).apply("max").asInstanceOf[Double]
     assert(doubleVal1 == Double.NegativeInfinity)
 
-    val doubleVal2 = CSVTypeCast.castTo(
-      "max", "_1", DoubleType, nullable = true, CSVOptions("positiveInf", "max")
-    ).asInstanceOf[Double]
+    val doubleVal2 = CSVTypeCast.makeConverter(
+      "_1", DoubleType, nullable = true, CSVOptions("positiveInf", "max")
+    ).apply("max").asInstanceOf[Double]
    assert(doubleVal2 == Double.PositiveInfinity)
  }
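
Note on the pattern (not part of the patch): the change replaces per-value `castTo` dispatch with converters that are resolved once per schema field and then reused for every row, so the type dispatch leaves the per-row loop in `CSVRelation.csvParser`. The standalone Scala sketch below illustrates the same converter-per-field idea in a self-contained form; `ConverterSketch`, `SimpleType`, `IntT`, `DoubleT`, and `StringT` are hypothetical names used only for illustration and are not Spark APIs.

// Minimal sketch of the converter-per-field pattern, under simplified assumptions.
object ConverterSketch {
  type ValueConverter = String => Any

  sealed trait SimpleType
  case object IntT extends SimpleType
  case object DoubleT extends SimpleType
  case object StringT extends SimpleType

  // Analogous to makeConverter: dispatch on the type once and return a reusable function.
  def makeConverter(tpe: SimpleType, nullValue: String = ""): ValueConverter = tpe match {
    case IntT    => (d: String) => if (d == null || d == nullValue) null else d.toInt
    case DoubleT => (d: String) => if (d == null || d == nullValue) null else d.toDouble
    case StringT => (d: String) => d
  }

  def main(args: Array[String]): Unit = {
    // Analogous to makeConverters: one converter per field, resolved up front.
    val schema = Seq(IntT, DoubleT, StringT)
    val converters: Array[ValueConverter] = schema.map(makeConverter(_)).toArray

    // Per-row work is then an index lookup plus a function call, with no type dispatch.
    val tokens = Array("10", "1.5", "spark")
    val row = tokens.indices.map(i => converters(i).apply(tokens(i)))
    println(row) // Vector(10, 1.5, spark)
  }
}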