From c96d14abae5962a7b15239319c2a151b95f7db94 Mon Sep 17 00:00:00 2001 From: Tejas Patil <tejasp@fb.com> Date: Tue, 7 Mar 2017 20:19:30 -0800 Subject: [PATCH] [SPARK-19843][SQL] UTF8String => (int / long) conversion expensive for invalid inputs ## What changes were proposed in this pull request? Jira: https://issues.apache.org/jira/browse/SPARK-19843 Created wrapper classes (`IntWrapper`, `LongWrapper`) to hold the parsed result (a primitive `int` or `long`). The parsing methods now return a boolean indicating whether parsing succeeded, instead of throwing a `NumberFormatException` on invalid input. ## How was this patch tested? - Added new unit tests - Ran a prod job which had conversion from string -> int and verified the outputs ## Performance Tiny regression when all strings are valid integers ``` conversion to int: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------- trunk 502 / 522 33.4 29.9 1.0X SPARK-19843 493 / 503 34.0 29.4 1.0X ``` Huge gain when all strings are invalid integers ``` conversion to int: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------- trunk 33913 / 34219 0.5 2021.4 1.0X SPARK-19843 154 / 162 108.8 9.2 220.0X ``` Author: Tejas Patil <tejasp@fb.com> Closes #17184 from tejasapatil/SPARK-19843_is_numeric_maybe. 
--- .../apache/spark/unsafe/types/UTF8String.java | 120 +++++++++------- .../spark/unsafe/types/UTF8StringSuite.java | 128 +++++++++++++++++- .../spark/sql/catalyst/expressions/Cast.scala | 81 ++++++----- 3 files changed, 247 insertions(+), 82 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 10a7cb1d06..7abe0fa80a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -850,11 +850,8 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, return fromString(sb.toString()); } - private int getDigit(byte b) { - if (b >= '0' && b <= '9') { - return b - '0'; - } - throw new NumberFormatException(toString()); + public static class LongWrapper { + public long value = 0; } /** @@ -862,14 +859,18 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, * * Note that, in this method we accumulate the result in negative format, and convert it to * positive format at the end, if this string is not started with '-'. This is because min value - * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and - * Integer.MIN_VALUE is '-2147483648'. + * is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and + * Long.MIN_VALUE is '-9223372036854775808'. * * This code is mostly copied from LazyLong.parseLong in Hive. 
+ * + * @param toLongResult If a valid `long` was parsed from this UTF8String, then its value would + * be set in `toLongResult` + * @return true if the parsing was successful else false */ - public long toLong() { + public boolean toLong(LongWrapper toLongResult) { if (numBytes == 0) { - throw new NumberFormatException("Empty string"); + return false; } byte b = getByte(0); @@ -878,7 +879,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, if (negative || b == '+') { offset++; if (numBytes == 1) { - throw new NumberFormatException(toString()); + return false; } } @@ -897,20 +898,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, break; } - int digit = getDigit(b); + int digit; + if (b >= '0' && b <= '9') { + digit = b - '0'; + } else { + return false; + } + // We are going to process the new digit and accumulate the result. However, before doing // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then - // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. + // result * 10 will definitely be smaller than minValue, and we can stop. if (result < stopValue) { - throw new NumberFormatException(toString()); + return false; } result = result * radix - digit; // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we - // can just use `result > 0` to check overflow. If result overflows, we should stop and throw - // exception. + // can just use `result > 0` to check overflow. If result overflows, we should stop. if (result > 0) { - throw new NumberFormatException(toString()); + return false; } } @@ -918,8 +924,9 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, // part will not change the number, but we will verify that the fractional part // is well formed. 
while (offset < numBytes) { - if (getDigit(getByte(offset)) == -1) { - throw new NumberFormatException(toString()); + byte currentByte = getByte(offset); + if (currentByte < '0' || currentByte > '9') { + return false; } offset++; } @@ -927,11 +934,16 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, if (!negative) { result = -result; if (result < 0) { - throw new NumberFormatException(toString()); + return false; } } - return result; + toLongResult.value = result; + return true; + } + + public static class IntWrapper { + public int value = 0; } /** @@ -946,10 +958,14 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, * * Note that, this method is almost same as `toLong`, but we leave it duplicated for performance * reasons, like Hive does. + * + * @param intWrapper If a valid `int` was parsed from this UTF8String, then its value would + * be set in `intWrapper` + * @return true if the parsing was successful else false */ - public int toInt() { + public boolean toInt(IntWrapper intWrapper) { if (numBytes == 0) { - throw new NumberFormatException("Empty string"); + return false; } byte b = getByte(0); @@ -958,7 +974,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, if (negative || b == '+') { offset++; if (numBytes == 1) { - throw new NumberFormatException(toString()); + return false; } } @@ -977,20 +993,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, break; } - int digit = getDigit(b); + int digit; + if (b >= '0' && b <= '9') { + digit = b - '0'; + } else { + return false; + } + // We are going to process the new digit and accumulate the result. However, before doing // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then - // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. 
+ // result * 10 will definitely be smaller than minValue, and we can stop if (result < stopValue) { - throw new NumberFormatException(toString()); + return false; } result = result * radix - digit; // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), - // we can just use `result > 0` to check overflow. If result overflows, we should stop and - // throw exception. + // we can just use `result > 0` to check overflow. If result overflows, we should stop if (result > 0) { - throw new NumberFormatException(toString()); + return false; } } @@ -998,8 +1019,9 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, // part will not change the number, but we will verify that the fractional part // is well formed. while (offset < numBytes) { - if (getDigit(getByte(offset)) == -1) { - throw new NumberFormatException(toString()); + byte currentByte = getByte(offset); + if (currentByte < '0' || currentByte > '9') { + return false; } offset++; } @@ -1007,31 +1029,33 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, if (!negative) { result = -result; if (result < 0) { - throw new NumberFormatException(toString()); + return false; } } - - return result; + intWrapper.value = result; + return true; } - public short toShort() { - int intValue = toInt(); - short result = (short) intValue; - if (result != intValue) { - throw new NumberFormatException(toString()); + public boolean toShort(IntWrapper intWrapper) { + if (toInt(intWrapper)) { + int intValue = intWrapper.value; + short result = (short) intValue; + if (result == intValue) { + return true; + } } - - return result; + return false; } - public byte toByte() { - int intValue = toInt(); - byte result = (byte) intValue; - if (result != intValue) { - throw new NumberFormatException(toString()); + public boolean toByte(IntWrapper intWrapper) { + if (toInt(intWrapper)) { + int intValue = intWrapper.value; + byte result = (byte) 
intValue; + if (result == intValue) { + return true; + } } - - return result; + return false; } @Override diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 6f6e0ef0e4..c376371abd 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -22,9 +22,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.HashMap; -import java.util.HashSet; +import java.util.*; import com.google.common.collect.ImmutableMap; import org.apache.spark.unsafe.Platform; @@ -608,4 +606,128 @@ public class UTF8StringSuite { .writeTo(outputStream); assertEquals("大åƒä¸–ç•Œ", outputStream.toString("UTF-8")); } + + @Test + public void testToShort() throws IOException { + Map<String, Short> inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", (short) 1); + inputToExpectedOutput.put("+1", (short) 1); + inputToExpectedOutput.put("-1", (short) -1); + inputToExpectedOutput.put("0", (short) 0); + inputToExpectedOutput.put("1111.12345678901234567890", (short) 1111); + inputToExpectedOutput.put(String.valueOf(Short.MAX_VALUE), Short.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Short.MIN_VALUE), Short.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + short value = (short) rand.nextInt(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper wrapper = new IntWrapper(); + for (Map.Entry<String, Short> entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toShort(wrapper)); + assertEquals((short) entry.getValue(), wrapper.value); + } + + List<String> negativeInputs = + Arrays.asList("", " ", "null", "NULL", 
"\n", "~1212121", "3276700"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toShort(wrapper)); + } + } + + @Test + public void testToByte() throws IOException { + Map<String, Byte> inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", (byte) 1); + inputToExpectedOutput.put("+1",(byte) 1); + inputToExpectedOutput.put("-1", (byte) -1); + inputToExpectedOutput.put("0", (byte) 0); + inputToExpectedOutput.put("111.12345678901234567890", (byte) 111); + inputToExpectedOutput.put(String.valueOf(Byte.MAX_VALUE), Byte.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Byte.MIN_VALUE), Byte.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + byte value = (byte) rand.nextInt(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper intWrapper = new IntWrapper(); + for (Map.Entry<String, Byte> entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toByte(intWrapper)); + assertEquals((byte) entry.getValue(), intWrapper.value); + } + + List<String> negativeInputs = + Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toByte(intWrapper)); + } + } + + @Test + public void testToInt() throws IOException { + Map<String, Integer> inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", 1); + inputToExpectedOutput.put("+1", 1); + inputToExpectedOutput.put("-1", -1); + inputToExpectedOutput.put("0", 0); + inputToExpectedOutput.put("11111.1234567", 11111); + inputToExpectedOutput.put(String.valueOf(Integer.MAX_VALUE), Integer.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Integer.MIN_VALUE), Integer.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + int value = rand.nextInt(); + 
inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper intWrapper = new IntWrapper(); + for (Map.Entry<String, Integer> entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toInt(intWrapper)); + assertEquals((int) entry.getValue(), intWrapper.value); + } + + List<String> negativeInputs = + Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toInt(intWrapper)); + } + } + + @Test + public void testToLong() throws IOException { + Map<String, Long> inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", 1L); + inputToExpectedOutput.put("+1", 1L); + inputToExpectedOutput.put("-1", -1L); + inputToExpectedOutput.put("0", 0L); + inputToExpectedOutput.put("1076753423.12345678901234567890", 1076753423L); + inputToExpectedOutput.put(String.valueOf(Long.MAX_VALUE), Long.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Long.MIN_VALUE), Long.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + long value = rand.nextLong(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + LongWrapper wrapper = new LongWrapper(); + for (Map.Entry<String, Long> entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toLong(wrapper)); + assertEquals((long) entry.getValue(), wrapper.value); + } + + List<String> negativeInputs = Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", + "1234567890123456789012345678901234"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper)); + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 
a36d3507d9..7c60f7d57a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} - +import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper} object Cast { @@ -277,9 +277,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toLong catch { - case _: NumberFormatException => null - }) + val result = new LongWrapper() + buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) case DateType => @@ -293,9 +292,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toInt catch { - case _: NumberFormatException => null - }) + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => @@ -309,8 +307,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toShort catch { - case _: NumberFormatException => null + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toShort(result)) { + result.value.toShort + } else { + null }) case BooleanType => buildCast[Boolean](_, b 
=> if (b) 1.toShort else 0.toShort) @@ -325,8 +326,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toByte catch { - case _: NumberFormatException => null + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toByte(result)) { + result.value.toByte + } else { + null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) @@ -503,11 +507,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case TimestampType => castToTimestampCode(from, ctx) case CalendarIntervalType => castToIntervalCode(from) case BooleanType => castToBooleanCode(from) - case ByteType => castToByteCode(from) - case ShortType => castToShortCode(from) - case IntegerType => castToIntCode(from) + case ByteType => castToByteCode(from, ctx) + case ShortType => castToShortCode(from, ctx) + case IntegerType => castToIntCode(from, ctx) case FloatType => castToFloatCode(from) - case LongType => castToLongCode(from) + case LongType => castToLongCode(from, ctx) case DoubleType => castToDoubleCode(from) case array: ArrayType => @@ -734,13 +738,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"$evPrim = $c != 0;" } - private[this] def castToByteCode(from: DataType): CastFunction = from match { + private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => + val wrapper = ctx.freshName("wrapper") + ctx.addMutableState("UTF8String.IntWrapper", wrapper, + s"$wrapper = new UTF8String.IntWrapper();") (c, evPrim, evNull) => s""" - try { - $evPrim = $c.toByte(); - } catch (java.lang.NumberFormatException e) { + if ($c.toByte($wrapper)) { + $evPrim = (byte) $wrapper.value; + } else { $evNull = true; } """ @@ -756,13 +763,18 @@ case class 
Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"$evPrim = (byte) $c;" } - private[this] def castToShortCode(from: DataType): CastFunction = from match { + private[this] def castToShortCode( + from: DataType, + ctx: CodegenContext): CastFunction = from match { case StringType => + val wrapper = ctx.freshName("wrapper") + ctx.addMutableState("UTF8String.IntWrapper", wrapper, + s"$wrapper = new UTF8String.IntWrapper();") (c, evPrim, evNull) => s""" - try { - $evPrim = $c.toShort(); - } catch (java.lang.NumberFormatException e) { + if ($c.toShort($wrapper)) { + $evPrim = (short) $wrapper.value; + } else { $evNull = true; } """ @@ -778,13 +790,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"$evPrim = (short) $c;" } - private[this] def castToIntCode(from: DataType): CastFunction = from match { + private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => + val wrapper = ctx.freshName("wrapper") + ctx.addMutableState("UTF8String.IntWrapper", wrapper, + s"$wrapper = new UTF8String.IntWrapper();") (c, evPrim, evNull) => s""" - try { - $evPrim = $c.toInt(); - } catch (java.lang.NumberFormatException e) { + if ($c.toInt($wrapper)) { + $evPrim = $wrapper.value; + } else { $evNull = true; } """ @@ -800,13 +815,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"$evPrim = (int) $c;" } - private[this] def castToLongCode(from: DataType): CastFunction = from match { + private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => + val wrapper = ctx.freshName("wrapper") + ctx.addMutableState("UTF8String.LongWrapper", wrapper, + s"$wrapper = new UTF8String.LongWrapper();") + (c, evPrim, evNull) => s""" - try { - $evPrim = $c.toLong(); - } catch (java.lang.NumberFormatException e) { + if 
($c.toLong($wrapper)) { + $evPrim = $wrapper.value; + } else { $evNull = true; } """ -- GitLab