Skip to content
Snippets Groups Projects
Commit c96d14ab authored by Tejas Patil's avatar Tejas Patil Committed by Wenchen Fan
Browse files

[SPARK-19843][SQL] UTF8String => (int / long) conversion expensive for invalid inputs

## What changes were proposed in this pull request?

Jira : https://issues.apache.org/jira/browse/SPARK-19843

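Created wrapper classes (`IntWrapper`, `LongWrapper`) to hold the result of parsing (which is a primitive type). The parsing methods now return a boolean that indicates whether parsing succeeded, instead of throwing a `NumberFormatException` for invalid input; on success, the parsed value is stored in the wrapper.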

## How was this patch tested?

- Added new unit tests
- Ran a prod job which had conversion from string -> int and verified the outputs

## Performance

No meaningful change when all strings are valid integers (1.0X relative in the benchmark below)

```
conversion to int:       Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
--------------------------------------------------------------------------------
trunk                         502 /  522         33.4          29.9       1.0X
SPARK-19843                   493 /  503         34.0          29.4       1.0X
```

Huge gain when all strings are invalid integers
```
conversion to int:      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------
trunk                     33913 / 34219          0.5        2021.4       1.0X
SPARK-19843                  154 /  162        108.8           9.2     220.0X
```

Author: Tejas Patil <tejasp@fb.com>

Closes #17184 from tejasapatil/SPARK-19843_is_numeric_maybe.
parent 47b2f68a
No related branches found
No related tags found
No related merge requests found
......@@ -850,11 +850,8 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
return fromString(sb.toString());
}
private int getDigit(byte b) {
if (b >= '0' && b <= '9') {
return b - '0';
}
throw new NumberFormatException(toString());
public static class LongWrapper {
public long value = 0;
}
/**
......@@ -862,14 +859,18 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
*
* Note that, in this method we accumulate the result in negative format, and convert it to
* positive format at the end, if this string is not started with '-'. This is because min value
* is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and
* Integer.MIN_VALUE is '-2147483648'.
* is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and
* Long.MIN_VALUE is '-9223372036854775808'.
*
* This code is mostly copied from LazyLong.parseLong in Hive.
*
* @param toLongResult If a valid `long` was parsed from this UTF8String, then its value would
* be set in `toLongResult`
* @return true if the parsing was successful else false
*/
public long toLong() {
public boolean toLong(LongWrapper toLongResult) {
if (numBytes == 0) {
throw new NumberFormatException("Empty string");
return false;
}
byte b = getByte(0);
......@@ -878,7 +879,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
if (negative || b == '+') {
offset++;
if (numBytes == 1) {
throw new NumberFormatException(toString());
return false;
}
}
......@@ -897,20 +898,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
break;
}
int digit = getDigit(b);
int digit;
if (b >= '0' && b <= '9') {
digit = b - '0';
} else {
return false;
}
// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then
// result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
// result * 10 will definitely be smaller than minValue, and we can stop.
if (result < stopValue) {
throw new NumberFormatException(toString());
return false;
}
result = result * radix - digit;
// Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we
// can just use `result > 0` to check overflow. If result overflows, we should stop and throw
// exception.
// can just use `result > 0` to check overflow. If result overflows, we should stop.
if (result > 0) {
throw new NumberFormatException(toString());
return false;
}
}
......@@ -918,8 +924,9 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
// part will not change the number, but we will verify that the fractional part
// is well formed.
while (offset < numBytes) {
if (getDigit(getByte(offset)) == -1) {
throw new NumberFormatException(toString());
byte currentByte = getByte(offset);
if (currentByte < '0' || currentByte > '9') {
return false;
}
offset++;
}
......@@ -927,11 +934,16 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
if (!negative) {
result = -result;
if (result < 0) {
throw new NumberFormatException(toString());
return false;
}
}
return result;
toLongResult.value = result;
return true;
}
public static class IntWrapper {
public int value = 0;
}
/**
......@@ -946,10 +958,14 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
*
* Note that, this method is almost same as `toLong`, but we leave it duplicated for performance
* reasons, like Hive does.
*
* @param intWrapper If a valid `int` was parsed from this UTF8String, then its value would
* be set in `intWrapper`
* @return true if the parsing was successful else false
*/
public int toInt() {
public boolean toInt(IntWrapper intWrapper) {
if (numBytes == 0) {
throw new NumberFormatException("Empty string");
return false;
}
byte b = getByte(0);
......@@ -958,7 +974,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
if (negative || b == '+') {
offset++;
if (numBytes == 1) {
throw new NumberFormatException(toString());
return false;
}
}
......@@ -977,20 +993,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
break;
}
int digit = getDigit(b);
int digit;
if (b >= '0' && b <= '9') {
digit = b - '0';
} else {
return false;
}
// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then
// result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
// result * 10 will definitely be smaller than minValue, and we can stop
if (result < stopValue) {
throw new NumberFormatException(toString());
return false;
}
result = result * radix - digit;
// Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix),
// we can just use `result > 0` to check overflow. If result overflows, we should stop and
// throw exception.
// we can just use `result > 0` to check overflow. If result overflows, we should stop
if (result > 0) {
throw new NumberFormatException(toString());
return false;
}
}
......@@ -998,8 +1019,9 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
// part will not change the number, but we will verify that the fractional part
// is well formed.
while (offset < numBytes) {
if (getDigit(getByte(offset)) == -1) {
throw new NumberFormatException(toString());
byte currentByte = getByte(offset);
if (currentByte < '0' || currentByte > '9') {
return false;
}
offset++;
}
......@@ -1007,31 +1029,33 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
if (!negative) {
result = -result;
if (result < 0) {
throw new NumberFormatException(toString());
return false;
}
}
return result;
intWrapper.value = result;
return true;
}
public short toShort() {
int intValue = toInt();
short result = (short) intValue;
if (result != intValue) {
throw new NumberFormatException(toString());
public boolean toShort(IntWrapper intWrapper) {
if (toInt(intWrapper)) {
int intValue = intWrapper.value;
short result = (short) intValue;
if (result == intValue) {
return true;
}
}
return result;
return false;
}
public byte toByte() {
int intValue = toInt();
byte result = (byte) intValue;
if (result != intValue) {
throw new NumberFormatException(toString());
public boolean toByte(IntWrapper intWrapper) {
if (toInt(intWrapper)) {
int intValue = intWrapper.value;
byte result = (byte) intValue;
if (result == intValue) {
return true;
}
}
return result;
return false;
}
@Override
......
......@@ -22,9 +22,7 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.*;
import com.google.common.collect.ImmutableMap;
import org.apache.spark.unsafe.Platform;
......@@ -608,4 +606,128 @@ public class UTF8StringSuite {
.writeTo(outputStream);
assertEquals("大千世界", outputStream.toString("UTF-8"));
}
/**
 * Checks {@code UTF8String.toShort(IntWrapper)}: every well-formed input must be reported as
 * parsed with the expected value left in the wrapper, and every malformed or out-of-range input
 * must be reported as not parsed.
 */
@Test
public void testToShort() throws IOException {
  // Inputs that must parse, keyed by their textual form.
  Map<String, Short> expectedByInput = new HashMap<>();
  expectedByInput.put("1", (short) 1);
  expectedByInput.put("+1", (short) 1);
  expectedByInput.put("-1", (short) -1);
  expectedByInput.put("0", (short) 0);
  // A fractional part is accepted and the integral part is what ends up in the wrapper.
  expectedByInput.put("1111.12345678901234567890", (short) 1111);
  expectedByInput.put(String.valueOf(Short.MAX_VALUE), Short.MAX_VALUE);
  expectedByInput.put(String.valueOf(Short.MIN_VALUE), Short.MIN_VALUE);

  // Add a few arbitrary shorts on top of the hand-picked cases.
  Random random = new Random();
  for (int n = 0; n < 10; n++) {
    short sample = (short) random.nextInt();
    expectedByInput.put(String.valueOf(sample), sample);
  }

  // A single wrapper is reused for every call, mirroring how callers use the API.
  IntWrapper wrapper = new IntWrapper();
  for (Map.Entry<String, Short> pair : expectedByInput.entrySet()) {
    String input = pair.getKey();
    boolean parsed = UTF8String.fromString(input).toShort(wrapper);
    assertTrue(input, parsed);
    assertEquals((short) pair.getValue(), wrapper.value);
  }

  // Empty, blank, non-numeric and out-of-short-range inputs must all be rejected.
  String[] rejectedInputs = {"", " ", "null", "NULL", "\n", "~1212121", "3276700"};
  for (String input : rejectedInputs) {
    boolean parsed = UTF8String.fromString(input).toShort(wrapper);
    assertFalse(input, parsed);
  }
}
/**
 * Checks {@code UTF8String.toByte(IntWrapper)}: well-formed inputs within the byte range must be
 * reported as parsed with the expected value left in the wrapper; malformed inputs and inputs
 * outside the byte range must be reported as not parsed.
 */
@Test
public void testToByte() throws IOException {
  Map<String, Byte> inputToExpectedOutput = new HashMap<>();
  inputToExpectedOutput.put("1", (byte) 1);
  inputToExpectedOutput.put("+1", (byte) 1);
  inputToExpectedOutput.put("-1", (byte) -1);
  inputToExpectedOutput.put("0", (byte) 0);
  // A fractional part is accepted and the integral part is what ends up in the wrapper.
  inputToExpectedOutput.put("111.12345678901234567890", (byte) 111);
  inputToExpectedOutput.put(String.valueOf(Byte.MAX_VALUE), Byte.MAX_VALUE);
  inputToExpectedOutput.put(String.valueOf(Byte.MIN_VALUE), Byte.MIN_VALUE);

  Random rand = new Random();
  for (int i = 0; i < 10; i++) {
    byte value = (byte) rand.nextInt();
    inputToExpectedOutput.put(String.valueOf(value), value);
  }

  IntWrapper intWrapper = new IntWrapper();
  for (Map.Entry<String, Byte> entry : inputToExpectedOutput.entrySet()) {
    String input = entry.getKey();
    assertTrue(input, UTF8String.fromString(input).toByte(intWrapper));
    // Carry the input in the message: some inputs are random, so a bare value mismatch would
    // not say which string was being parsed.
    assertEquals(input, (byte) entry.getValue(), intWrapper.value);
  }

  // "128" and "-129" are valid ints that do not fit in a byte, so they exercise the narrowing
  // check inside toByte itself. "12345678901234567890" is already rejected while parsing the int.
  List<String> negativeInputs = Arrays.asList(
    "", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890", "128", "-129");
  for (String negativeInput : negativeInputs) {
    assertFalse(negativeInput, UTF8String.fromString(negativeInput).toByte(intWrapper));
  }
}
/**
 * Checks {@code UTF8String.toInt(IntWrapper)}: well-formed inputs within the int range must be
 * reported as parsed with the expected value left in the wrapper; malformed inputs and inputs
 * outside the int range must be reported as not parsed.
 */
@Test
public void testToInt() throws IOException {
  Map<String, Integer> inputToExpectedOutput = new HashMap<>();
  inputToExpectedOutput.put("1", 1);
  inputToExpectedOutput.put("+1", 1);
  inputToExpectedOutput.put("-1", -1);
  inputToExpectedOutput.put("0", 0);
  // A fractional part is accepted and the integral part is what ends up in the wrapper.
  inputToExpectedOutput.put("11111.1234567", 11111);
  inputToExpectedOutput.put(String.valueOf(Integer.MAX_VALUE), Integer.MAX_VALUE);
  inputToExpectedOutput.put(String.valueOf(Integer.MIN_VALUE), Integer.MIN_VALUE);

  Random rand = new Random();
  for (int i = 0; i < 10; i++) {
    int value = rand.nextInt();
    inputToExpectedOutput.put(String.valueOf(value), value);
  }

  IntWrapper intWrapper = new IntWrapper();
  for (Map.Entry<String, Integer> entry : inputToExpectedOutput.entrySet()) {
    String input = entry.getKey();
    assertTrue(input, UTF8String.fromString(input).toInt(intWrapper));
    // Carry the input in the message: some inputs are random, so a bare value mismatch would
    // not say which string was being parsed.
    assertEquals(input, (int) entry.getValue(), intWrapper.value);
  }

  // "2147483648" (Integer.MAX_VALUE + 1) and "-2147483649" (Integer.MIN_VALUE - 1) overflow only
  // on the very last step, unlike "12345678901234567890" which is far outside the int range.
  List<String> negativeInputs = Arrays.asList(
    "", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890",
    "2147483648", "-2147483649");
  for (String negativeInput : negativeInputs) {
    assertFalse(negativeInput, UTF8String.fromString(negativeInput).toInt(intWrapper));
  }
}
/**
 * Checks {@code UTF8String.toLong(LongWrapper)}: well-formed inputs within the long range must be
 * reported as parsed with the expected value left in the wrapper; malformed inputs and inputs
 * outside the long range must be reported as not parsed.
 */
@Test
public void testToLong() throws IOException {
  Map<String, Long> inputToExpectedOutput = new HashMap<>();
  inputToExpectedOutput.put("1", 1L);
  inputToExpectedOutput.put("+1", 1L);
  inputToExpectedOutput.put("-1", -1L);
  inputToExpectedOutput.put("0", 0L);
  // A fractional part is accepted and the integral part is what ends up in the wrapper.
  inputToExpectedOutput.put("1076753423.12345678901234567890", 1076753423L);
  inputToExpectedOutput.put(String.valueOf(Long.MAX_VALUE), Long.MAX_VALUE);
  inputToExpectedOutput.put(String.valueOf(Long.MIN_VALUE), Long.MIN_VALUE);

  Random rand = new Random();
  for (int i = 0; i < 10; i++) {
    long value = rand.nextLong();
    inputToExpectedOutput.put(String.valueOf(value), value);
  }

  LongWrapper wrapper = new LongWrapper();
  for (Map.Entry<String, Long> entry : inputToExpectedOutput.entrySet()) {
    String input = entry.getKey();
    assertTrue(input, UTF8String.fromString(input).toLong(wrapper));
    // Carry the input in the message: some inputs are random, so a bare value mismatch would
    // not say which string was being parsed.
    assertEquals(input, (long) entry.getValue(), wrapper.value);
  }

  // "9223372036854775808" (Long.MAX_VALUE + 1) and "-9223372036854775809" (Long.MIN_VALUE - 1)
  // overflow only on the very last step, unlike the 34-digit input which is far out of range.
  List<String> negativeInputs = Arrays.asList("", " ", "null", "NULL", "\n", "~1212121",
    "1234567890123456789012345678901234", "9223372036854775808", "-9223372036854775809");
  for (String negativeInput : negativeInputs) {
    assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper));
  }
}
}
......@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}
object Cast {
......@@ -277,9 +277,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try s.toLong catch {
case _: NumberFormatException => null
})
val result = new LongWrapper()
buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1L else 0L)
case DateType =>
......@@ -293,9 +292,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try s.toInt catch {
case _: NumberFormatException => null
})
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1 else 0)
case DateType =>
......@@ -309,8 +307,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try s.toShort catch {
case _: NumberFormatException => null
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toShort(result)) {
result.value.toShort
} else {
null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
......@@ -325,8 +326,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try s.toByte catch {
case _: NumberFormatException => null
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toByte(result)) {
result.value.toByte
} else {
null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
......@@ -503,11 +507,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case TimestampType => castToTimestampCode(from, ctx)
case CalendarIntervalType => castToIntervalCode(from)
case BooleanType => castToBooleanCode(from)
case ByteType => castToByteCode(from)
case ShortType => castToShortCode(from)
case IntegerType => castToIntCode(from)
case ByteType => castToByteCode(from, ctx)
case ShortType => castToShortCode(from, ctx)
case IntegerType => castToIntCode(from, ctx)
case FloatType => castToFloatCode(from)
case LongType => castToLongCode(from)
case LongType => castToLongCode(from, ctx)
case DoubleType => castToDoubleCode(from)
case array: ArrayType =>
......@@ -734,13 +738,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = $c != 0;"
}
private[this] def castToByteCode(from: DataType): CastFunction = from match {
private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
val wrapper = ctx.freshName("wrapper")
ctx.addMutableState("UTF8String.IntWrapper", wrapper,
s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) =>
s"""
try {
$evPrim = $c.toByte();
} catch (java.lang.NumberFormatException e) {
if ($c.toByte($wrapper)) {
$evPrim = (byte) $wrapper.value;
} else {
$evNull = true;
}
"""
......@@ -756,13 +763,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (byte) $c;"
}
private[this] def castToShortCode(from: DataType): CastFunction = from match {
private[this] def castToShortCode(
from: DataType,
ctx: CodegenContext): CastFunction = from match {
case StringType =>
val wrapper = ctx.freshName("wrapper")
ctx.addMutableState("UTF8String.IntWrapper", wrapper,
s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) =>
s"""
try {
$evPrim = $c.toShort();
} catch (java.lang.NumberFormatException e) {
if ($c.toShort($wrapper)) {
$evPrim = (short) $wrapper.value;
} else {
$evNull = true;
}
"""
......@@ -778,13 +790,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (short) $c;"
}
private[this] def castToIntCode(from: DataType): CastFunction = from match {
private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
val wrapper = ctx.freshName("wrapper")
ctx.addMutableState("UTF8String.IntWrapper", wrapper,
s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) =>
s"""
try {
$evPrim = $c.toInt();
} catch (java.lang.NumberFormatException e) {
if ($c.toInt($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
}
"""
......@@ -800,13 +815,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (int) $c;"
}
private[this] def castToLongCode(from: DataType): CastFunction = from match {
private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
val wrapper = ctx.freshName("wrapper")
ctx.addMutableState("UTF8String.LongWrapper", wrapper,
s"$wrapper = new UTF8String.LongWrapper();")
(c, evPrim, evNull) =>
s"""
try {
$evPrim = $c.toLong();
} catch (java.lang.NumberFormatException e) {
if ($c.toLong($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
}
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment