diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 0255f53113e48167dab5c64d87340ae3f4c5e73a..3800d53c02f4c84dbe672210e09595f6ba9140f3 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -835,6 +835,190 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, return fromString(sb.toString()); } + private int getDigit(byte b) { + if (b >= '0' && b <= '9') { + return b - '0'; + } + throw new NumberFormatException(toString()); + } + + /** + * Parses this UTF8String to long. + * + * Note that, in this method we accumulate the result in negative format, and convert it to + * positive format at the end, if this string is not started with '-'. This is because min value + * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and + * Integer.MIN_VALUE is '-2147483648'. + * + * This code is mostly copied from LazyLong.parseLong in Hive. + */ + public long toLong() { + if (numBytes == 0) { + throw new NumberFormatException("Empty string"); + } + + byte b = getByte(0); + final boolean negative = b == '-'; + int offset = 0; + if (negative || b == '+') { + offset++; + if (numBytes == 1) { + throw new NumberFormatException(toString()); + } + } + + final byte separator = '.'; + final int radix = 10; + final long stopValue = Long.MIN_VALUE / radix; + long result = 0; + + while (offset < numBytes) { + b = getByte(offset); + offset++; + if (b == separator) { + // We allow decimals and will return a truncated integral in that case. + // Therefore we won't throw an exception here (checking the fractional + // part happens below.) + break; + } + + int digit = getDigit(b); + // We are going to process the new digit and accumulate the result. However, before doing + // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then + // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. + if (result < stopValue) { + throw new NumberFormatException(toString()); + } + + result = result * radix - digit; + // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we + // can just use `result > 0` to check overflow. If result overflows, we should stop and throw + // exception. + if (result > 0) { + throw new NumberFormatException(toString()); + } + } + + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well formed. + while (offset < numBytes) { + if (getDigit(getByte(offset)) == -1) { + throw new NumberFormatException(toString()); + } + offset++; + } + + if (!negative) { + result = -result; + if (result < 0) { + throw new NumberFormatException(toString()); + } + } + + return result; + } + + /** + * Parses this UTF8String to int. + * + * Note that, in this method we accumulate the result in negative format, and convert it to + * positive format at the end, if this string is not started with '-'. This is because min value + * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and + * Integer.MIN_VALUE is '-2147483648'. + * + * This code is mostly copied from LazyInt.parseInt in Hive. + * + * Note that, this method is almost same as `toLong`, but we leave it duplicated for performance + * reasons, like Hive does. + */ + public int toInt() { + if (numBytes == 0) { + throw new NumberFormatException("Empty string"); + } + + byte b = getByte(0); + final boolean negative = b == '-'; + int offset = 0; + if (negative || b == '+') { + offset++; + if (numBytes == 1) { + throw new NumberFormatException(toString()); + } + } + + final byte separator = '.'; + final int radix = 10; + final int stopValue = Integer.MIN_VALUE / radix; + int result = 0; + + while (offset < numBytes) { + b = getByte(offset); + offset++; + if (b == separator) { + // We allow decimals and will return a truncated integral in that case. + // Therefore we won't throw an exception here (checking the fractional + // part happens below.) + break; + } + + int digit = getDigit(b); + // We are going to process the new digit and accumulate the result. However, before doing + // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then + // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. + if (result < stopValue) { + throw new NumberFormatException(toString()); + } + + result = result * radix - digit; + // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), + // we can just use `result > 0` to check overflow. If result overflows, we should stop and + // throw exception. + if (result > 0) { + throw new NumberFormatException(toString()); + } + } + + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well formed. + while (offset < numBytes) { + if (getDigit(getByte(offset)) == -1) { + throw new NumberFormatException(toString()); + } + offset++; + } + + if (!negative) { + result = -result; + if (result < 0) { + throw new NumberFormatException(toString()); + } + } + + return result; + } + + public short toShort() { + int intValue = toInt(); + short result = (short) intValue; + if (result != intValue) { + throw new NumberFormatException(toString()); + } + + return result; + } + + public byte toByte() { + int intValue = toInt(); + byte result = (byte) intValue; + if (result != intValue) { + throw new NumberFormatException(toString()); + } + + return result; + } + @Override public String toString() { return new String(getBytes(), StandardCharsets.UTF_8); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index cd73f9c897bcd98150257aed10df8ed2fd8fb1af..5f72fa8536e6108b299d34994b4ba64ee892f3a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -51,7 +51,6 @@ object TypeCoercion { PromoteStrings :: DecimalPrecision :: BooleanEquality :: - StringToIntegralCasts :: FunctionArgumentConversion :: CaseWhenCoercion :: IfCoercion :: @@ -428,21 +427,6 @@ object TypeCoercion { } } - /** - * When encountering a cast from a string representing a valid fractional number to an integral - * type the jvm will throw a `java.lang.NumberFormatException`. Hive, in contrast, returns the - * truncated version of this number. - */ - object StringToIntegralCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e - - case Cast(e @ StringType(), t: IntegralType) => - Cast(Cast(e, DecimalType.forType(LongType)), t) - } - } - /** * This ensure that the types for various functions are as expected. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 741730e3e0be45e66afec1487f6adc17c926bf35..14e275bf88e61ae09fd412e0dbb404ff1367706e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -247,7 +247,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toString.toLong catch { + buildCast[UTF8String](_, s => try s.toLong catch { case _: NumberFormatException => null }) case BooleanType => @@ -263,7 +263,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toString.toInt catch { + buildCast[UTF8String](_, s => try s.toInt catch { case _: NumberFormatException => null }) case BooleanType => @@ -279,7 +279,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toString.toShort catch { + buildCast[UTF8String](_, s => try s.toShort catch { case _: NumberFormatException => null }) case BooleanType => @@ -295,7 +295,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toString.toByte catch { + buildCast[UTF8String](_, s => try s.toByte catch { case _: NumberFormatException => null }) case BooleanType => @@ -498,7 +498,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w s""" boolean $resultNull = $childNull; ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)}; - if (!${childNull}) { + if (!$childNull) { ${cast(childPrim, resultPrim, resultNull)} } """ @@ -705,7 +705,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s""" try { - $evPrim = Byte.valueOf($c.toString()); + $evPrim = $c.toByte(); } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -727,7 +727,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s""" try { - $evPrim = Short.valueOf($c.toString()); + $evPrim = $c.toShort(); } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -749,7 +749,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s""" try { - $evPrim = Integer.valueOf($c.toString()); + $evPrim = $c.toInt(); } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -771,7 +771,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s""" try { - $evPrim = Long.valueOf($c.toString()); + $evPrim = $c.toLong(); } catch (java.lang.NumberFormatException e) { $evNull = true; } diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql new file mode 100644 index 0000000000000000000000000000000000000000..5fae571945e4173c9e540775d65b30325064f794 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql @@ -0,0 +1,43 @@ +-- cast string representing a valid fractional number to integral should truncate the number +SELECT CAST('1.23' AS int); +SELECT CAST('1.23' AS long); +SELECT CAST('-4.56' AS int); +SELECT CAST('-4.56' AS long); + +-- cast string which are not numbers to integral should return null +SELECT CAST('abc' AS int); +SELECT CAST('abc' AS long); + +-- cast string representing a very large number to integral should return null +SELECT CAST('1234567890123' AS int); +SELECT CAST('12345678901234567890123' AS long); + +-- cast empty string to integral should return null +SELECT CAST('' AS int); +SELECT CAST('' AS long); + +-- cast null to integral should return null +SELECT CAST(NULL AS int); +SELECT CAST(NULL AS long); + +-- cast invalid decimal string to integral should return null +SELECT CAST('123.a' AS int); +SELECT CAST('123.a' AS long); + +-- '-2147483648' is the smallest int value +SELECT CAST('-2147483648' AS int); +SELECT CAST('-2147483649' AS int); + +-- '2147483647' is the largest int value +SELECT CAST('2147483647' AS int); +SELECT CAST('2147483648' AS int); + +-- '-9223372036854775808' is the smallest long value +SELECT CAST('-9223372036854775808' AS long); +SELECT CAST('-9223372036854775809' AS long); + +-- '9223372036854775807' is the largest long value +SELECT CAST('9223372036854775807' AS long); +SELECT CAST('9223372036854775808' AS long); + +-- TODO: migrate all cast tests here. diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out new file mode 100644 index 0000000000000000000000000000000000000000..bfa29d7d2d597658f2ffc442ed1399db1615c135 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out @@ -0,0 +1,178 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 22 + + +-- !query 0 +SELECT CAST('1.23' AS int) +-- !query 0 schema +struct<CAST(1.23 AS INT):int> +-- !query 0 output +1 + + +-- !query 1 +SELECT CAST('1.23' AS long) +-- !query 1 schema +struct<CAST(1.23 AS BIGINT):bigint> +-- !query 1 output +1 + + +-- !query 2 +SELECT CAST('-4.56' AS int) +-- !query 2 schema +struct<CAST(-4.56 AS INT):int> +-- !query 2 output +-4 + + +-- !query 3 +SELECT CAST('-4.56' AS long) +-- !query 3 schema +struct<CAST(-4.56 AS BIGINT):bigint> +-- !query 3 output +-4 + + +-- !query 4 +SELECT CAST('abc' AS int) +-- !query 4 schema +struct<CAST(abc AS INT):int> +-- !query 4 output +NULL + + +-- !query 5 +SELECT CAST('abc' AS long) +-- !query 5 schema +struct<CAST(abc AS BIGINT):bigint> +-- !query 5 output +NULL + + +-- !query 6 +SELECT CAST('1234567890123' AS int) +-- !query 6 schema +struct<CAST(1234567890123 AS INT):int> +-- !query 6 output +NULL + + +-- !query 7 +SELECT CAST('12345678901234567890123' AS long) +-- !query 7 schema +struct<CAST(12345678901234567890123 AS BIGINT):bigint> +-- !query 7 output +NULL + + +-- !query 8 +SELECT CAST('' AS int) +-- !query 8 schema +struct<CAST( AS INT):int> +-- !query 8 output +NULL + + +-- !query 9 +SELECT CAST('' AS long) +-- !query 9 schema +struct<CAST( AS BIGINT):bigint> +-- !query 9 output +NULL + + +-- !query 10 +SELECT CAST(NULL AS int) +-- !query 10 schema +struct<CAST(NULL AS INT):int> +-- !query 10 output +NULL + + +-- !query 11 +SELECT CAST(NULL AS long) +-- !query 11 schema +struct<CAST(NULL AS BIGINT):bigint> +-- !query 11 output +NULL + + +-- !query 12 +SELECT CAST('123.a' AS int) +-- !query 12 schema +struct<CAST(123.a AS INT):int> +-- !query 12 output +NULL + + +-- !query 13 +SELECT CAST('123.a' AS long) +-- !query 13 schema +struct<CAST(123.a AS BIGINT):bigint> +-- !query 13 output +NULL + + +-- !query 14 +SELECT CAST('-2147483648' AS int) +-- !query 14 schema +struct<CAST(-2147483648 AS INT):int> +-- !query 14 output +-2147483648 + + +-- !query 15 +SELECT CAST('-2147483649' AS int) +-- !query 15 schema +struct<CAST(-2147483649 AS INT):int> +-- !query 15 output +NULL + + +-- !query 16 +SELECT CAST('2147483647' AS int) +-- !query 16 schema +struct<CAST(2147483647 AS INT):int> +-- !query 16 output +2147483647 + + +-- !query 17 +SELECT CAST('2147483648' AS int) +-- !query 17 schema +struct<CAST(2147483648 AS INT):int> +-- !query 17 output +NULL + + +-- !query 18 +SELECT CAST('-9223372036854775808' AS long) +-- !query 18 schema +struct<CAST(-9223372036854775808 AS BIGINT):bigint> +-- !query 18 output +-9223372036854775808 + + +-- !query 19 +SELECT CAST('-9223372036854775809' AS long) +-- !query 19 schema +struct<CAST(-9223372036854775809 AS BIGINT):bigint> +-- !query 19 output +NULL + + +-- !query 20 +SELECT CAST('9223372036854775807' AS long) +-- !query 20 schema +struct<CAST(9223372036854775807 AS BIGINT):bigint> +-- !query 20 output +9223372036854775807 + + +-- !query 21 +SELECT CAST('9223372036854775808' AS long) +-- !query 21 schema +struct<CAST(9223372036854775808 AS BIGINT):bigint> +-- !query 21 output +NULL