Skip to content
Snippets Groups Projects
Commit c96d14ab authored by Tejas Patil's avatar Tejas Patil Committed by Wenchen Fan
Browse files

[SPARK-19843][SQL] UTF8String => (int / long) conversion expensive for invalid inputs

## What changes were proposed in this pull request?

Jira : https://issues.apache.org/jira/browse/SPARK-19843

Created wrapper classes (`IntWrapper`, `LongWrapper`) to wrap the result of parsing (which are primitive types). In case of problem in parsing, the method would return a boolean.

## How was this patch tested?

- Added new unit tests
- Ran a prod job which had conversion from string -> int and verified the outputs

## Performance

Tiny regression when all strings are valid integers

```
conversion to int:       Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
--------------------------------------------------------------------------------
trunk                         502 /  522         33.4          29.9       1.0X
SPARK-19843                   493 /  503         34.0          29.4       1.0X
```

Huge gain when all strings are invalid integers
```
conversion to int:      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------
trunk                     33913 / 34219          0.5        2021.4       1.0X
SPARK-19843                  154 /  162        108.8           9.2     220.0X
```

Author: Tejas Patil <tejasp@fb.com>

Closes #17184 from tejasapatil/SPARK-19843_is_numeric_maybe.
parent 47b2f68a
No related branches found
No related tags found
No related merge requests found
...@@ -850,11 +850,8 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, ...@@ -850,11 +850,8 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
return fromString(sb.toString()); return fromString(sb.toString());
} }
private int getDigit(byte b) { public static class LongWrapper {
if (b >= '0' && b <= '9') { public long value = 0;
return b - '0';
}
throw new NumberFormatException(toString());
} }
/** /**
...@@ -862,14 +859,18 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, ...@@ -862,14 +859,18 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
* *
* Note that, in this method we accumulate the result in negative format, and convert it to * Note that, in this method we accumulate the result in negative format, and convert it to
* positive format at the end, if this string is not started with '-'. This is because min value * positive format at the end, if this string is not started with '-'. This is because min value
* is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and * is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and
* Integer.MIN_VALUE is '-2147483648'. * Long.MIN_VALUE is '-9223372036854775808'.
* *
* This code is mostly copied from LazyLong.parseLong in Hive. * This code is mostly copied from LazyLong.parseLong in Hive.
*
* @param toLongResult If a valid `long` was parsed from this UTF8String, then its value would
* be set in `toLongResult`
* @return true if the parsing was successful else false
*/ */
public long toLong() { public boolean toLong(LongWrapper toLongResult) {
if (numBytes == 0) { if (numBytes == 0) {
throw new NumberFormatException("Empty string"); return false;
} }
byte b = getByte(0); byte b = getByte(0);
...@@ -878,7 +879,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, ...@@ -878,7 +879,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
if (negative || b == '+') { if (negative || b == '+') {
offset++; offset++;
if (numBytes == 1) { if (numBytes == 1) {
throw new NumberFormatException(toString()); return false;
} }
} }
...@@ -897,20 +898,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, ...@@ -897,20 +898,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
break; break;
} }
int digit = getDigit(b); int digit;
if (b >= '0' && b <= '9') {
digit = b - '0';
} else {
return false;
}
// We are going to process the new digit and accumulate the result. However, before doing // We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then
// result * 10 will definitely be smaller than minValue, and we can stop and throw exception. // result * 10 will definitely be smaller than minValue, and we can stop.
if (result < stopValue) { if (result < stopValue) {
throw new NumberFormatException(toString()); return false;
} }
result = result * radix - digit; result = result * radix - digit;
// Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we
// can just use `result > 0` to check overflow. If result overflows, we should stop and throw // can just use `result > 0` to check overflow. If result overflows, we should stop.
// exception.
if (result > 0) { if (result > 0) {
throw new NumberFormatException(toString()); return false;
} }
} }
...@@ -918,8 +924,9 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, ...@@ -918,8 +924,9 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
// part will not change the number, but we will verify that the fractional part // part will not change the number, but we will verify that the fractional part
// is well formed. // is well formed.
while (offset < numBytes) { while (offset < numBytes) {
if (getDigit(getByte(offset)) == -1) { byte currentByte = getByte(offset);
throw new NumberFormatException(toString()); if (currentByte < '0' || currentByte > '9') {
return false;
} }
offset++; offset++;
} }
...@@ -927,11 +934,16 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, ...@@ -927,11 +934,16 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
if (!negative) { if (!negative) {
result = -result; result = -result;
if (result < 0) { if (result < 0) {
throw new NumberFormatException(toString()); return false;
} }
} }
return result; toLongResult.value = result;
return true;
}
public static class IntWrapper {
public int value = 0;
} }
/** /**
...@@ -946,10 +958,14 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, ...@@ -946,10 +958,14 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
* *
* Note that, this method is almost same as `toLong`, but we leave it duplicated for performance * Note that, this method is almost same as `toLong`, but we leave it duplicated for performance
* reasons, like Hive does. * reasons, like Hive does.
*
* @param intWrapper If a valid `int` was parsed from this UTF8String, then its value would
* be set in `intWrapper`
* @return true if the parsing was successful else false
*/ */
public int toInt() { public boolean toInt(IntWrapper intWrapper) {
if (numBytes == 0) { if (numBytes == 0) {
throw new NumberFormatException("Empty string"); return false;
} }
byte b = getByte(0); byte b = getByte(0);
...@@ -958,7 +974,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, ...@@ -958,7 +974,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
if (negative || b == '+') { if (negative || b == '+') {
offset++; offset++;
if (numBytes == 1) { if (numBytes == 1) {
throw new NumberFormatException(toString()); return false;
} }
} }
...@@ -977,20 +993,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, ...@@ -977,20 +993,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
break; break;
} }
int digit = getDigit(b); int digit;
if (b >= '0' && b <= '9') {
digit = b - '0';
} else {
return false;
}
// We are going to process the new digit and accumulate the result. However, before doing // We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then
// result * 10 will definitely be smaller than minValue, and we can stop and throw exception. // result * 10 will definitely be smaller than minValue, and we can stop
if (result < stopValue) { if (result < stopValue) {
throw new NumberFormatException(toString()); return false;
} }
result = result * radix - digit; result = result * radix - digit;
// Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix),
// we can just use `result > 0` to check overflow. If result overflows, we should stop and // we can just use `result > 0` to check overflow. If result overflows, we should stop
// throw exception.
if (result > 0) { if (result > 0) {
throw new NumberFormatException(toString()); return false;
} }
} }
...@@ -998,8 +1019,9 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, ...@@ -998,8 +1019,9 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
// part will not change the number, but we will verify that the fractional part // part will not change the number, but we will verify that the fractional part
// is well formed. // is well formed.
while (offset < numBytes) { while (offset < numBytes) {
if (getDigit(getByte(offset)) == -1) { byte currentByte = getByte(offset);
throw new NumberFormatException(toString()); if (currentByte < '0' || currentByte > '9') {
return false;
} }
offset++; offset++;
} }
...@@ -1007,31 +1029,33 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, ...@@ -1007,31 +1029,33 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
if (!negative) { if (!negative) {
result = -result; result = -result;
if (result < 0) { if (result < 0) {
throw new NumberFormatException(toString()); return false;
} }
} }
intWrapper.value = result;
return result; return true;
} }
public short toShort() { public boolean toShort(IntWrapper intWrapper) {
int intValue = toInt(); if (toInt(intWrapper)) {
short result = (short) intValue; int intValue = intWrapper.value;
if (result != intValue) { short result = (short) intValue;
throw new NumberFormatException(toString()); if (result == intValue) {
return true;
}
} }
return false;
return result;
} }
public byte toByte() { public boolean toByte(IntWrapper intWrapper) {
int intValue = toInt(); if (toInt(intWrapper)) {
byte result = (byte) intValue; int intValue = intWrapper.value;
if (result != intValue) { byte result = (byte) intValue;
throw new NumberFormatException(toString()); if (result == intValue) {
return true;
}
} }
return false;
return result;
} }
@Override @Override
......
...@@ -22,9 +22,7 @@ import java.io.IOException; ...@@ -22,9 +22,7 @@ import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Arrays; import java.util.*;
import java.util.HashMap;
import java.util.HashSet;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.Platform;
...@@ -608,4 +606,128 @@ public class UTF8StringSuite { ...@@ -608,4 +606,128 @@ public class UTF8StringSuite {
.writeTo(outputStream); .writeTo(outputStream);
assertEquals("大千世界", outputStream.toString("UTF-8")); assertEquals("大千世界", outputStream.toString("UTF-8"));
} }
@Test
public void testToShort() throws IOException {
Map<String, Short> inputToExpectedOutput = new HashMap<>();
inputToExpectedOutput.put("1", (short) 1);
inputToExpectedOutput.put("+1", (short) 1);
inputToExpectedOutput.put("-1", (short) -1);
inputToExpectedOutput.put("0", (short) 0);
inputToExpectedOutput.put("1111.12345678901234567890", (short) 1111);
inputToExpectedOutput.put(String.valueOf(Short.MAX_VALUE), Short.MAX_VALUE);
inputToExpectedOutput.put(String.valueOf(Short.MIN_VALUE), Short.MIN_VALUE);
Random rand = new Random();
for (int i = 0; i < 10; i++) {
short value = (short) rand.nextInt();
inputToExpectedOutput.put(String.valueOf(value), value);
}
IntWrapper wrapper = new IntWrapper();
for (Map.Entry<String, Short> entry : inputToExpectedOutput.entrySet()) {
assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toShort(wrapper));
assertEquals((short) entry.getValue(), wrapper.value);
}
List<String> negativeInputs =
Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "3276700");
for (String negativeInput : negativeInputs) {
assertFalse(negativeInput, UTF8String.fromString(negativeInput).toShort(wrapper));
}
}
@Test
public void testToByte() throws IOException {
Map<String, Byte> inputToExpectedOutput = new HashMap<>();
inputToExpectedOutput.put("1", (byte) 1);
inputToExpectedOutput.put("+1",(byte) 1);
inputToExpectedOutput.put("-1", (byte) -1);
inputToExpectedOutput.put("0", (byte) 0);
inputToExpectedOutput.put("111.12345678901234567890", (byte) 111);
inputToExpectedOutput.put(String.valueOf(Byte.MAX_VALUE), Byte.MAX_VALUE);
inputToExpectedOutput.put(String.valueOf(Byte.MIN_VALUE), Byte.MIN_VALUE);
Random rand = new Random();
for (int i = 0; i < 10; i++) {
byte value = (byte) rand.nextInt();
inputToExpectedOutput.put(String.valueOf(value), value);
}
IntWrapper intWrapper = new IntWrapper();
for (Map.Entry<String, Byte> entry : inputToExpectedOutput.entrySet()) {
assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toByte(intWrapper));
assertEquals((byte) entry.getValue(), intWrapper.value);
}
List<String> negativeInputs =
Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890");
for (String negativeInput : negativeInputs) {
assertFalse(negativeInput, UTF8String.fromString(negativeInput).toByte(intWrapper));
}
}
@Test
public void testToInt() throws IOException {
Map<String, Integer> inputToExpectedOutput = new HashMap<>();
inputToExpectedOutput.put("1", 1);
inputToExpectedOutput.put("+1", 1);
inputToExpectedOutput.put("-1", -1);
inputToExpectedOutput.put("0", 0);
inputToExpectedOutput.put("11111.1234567", 11111);
inputToExpectedOutput.put(String.valueOf(Integer.MAX_VALUE), Integer.MAX_VALUE);
inputToExpectedOutput.put(String.valueOf(Integer.MIN_VALUE), Integer.MIN_VALUE);
Random rand = new Random();
for (int i = 0; i < 10; i++) {
int value = rand.nextInt();
inputToExpectedOutput.put(String.valueOf(value), value);
}
IntWrapper intWrapper = new IntWrapper();
for (Map.Entry<String, Integer> entry : inputToExpectedOutput.entrySet()) {
assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toInt(intWrapper));
assertEquals((int) entry.getValue(), intWrapper.value);
}
List<String> negativeInputs =
Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890");
for (String negativeInput : negativeInputs) {
assertFalse(negativeInput, UTF8String.fromString(negativeInput).toInt(intWrapper));
}
}
@Test
public void testToLong() throws IOException {
Map<String, Long> inputToExpectedOutput = new HashMap<>();
inputToExpectedOutput.put("1", 1L);
inputToExpectedOutput.put("+1", 1L);
inputToExpectedOutput.put("-1", -1L);
inputToExpectedOutput.put("0", 0L);
inputToExpectedOutput.put("1076753423.12345678901234567890", 1076753423L);
inputToExpectedOutput.put(String.valueOf(Long.MAX_VALUE), Long.MAX_VALUE);
inputToExpectedOutput.put(String.valueOf(Long.MIN_VALUE), Long.MIN_VALUE);
Random rand = new Random();
for (int i = 0; i < 10; i++) {
long value = rand.nextLong();
inputToExpectedOutput.put(String.valueOf(value), value);
}
LongWrapper wrapper = new LongWrapper();
for (Map.Entry<String, Long> entry : inputToExpectedOutput.entrySet()) {
assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toLong(wrapper));
assertEquals((long) entry.getValue(), wrapper.value);
}
List<String> negativeInputs = Arrays.asList("", " ", "null", "NULL", "\n", "~1212121",
"1234567890123456789012345678901234");
for (String negativeInput : negativeInputs) {
assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper));
}
}
} }
...@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ ...@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}
object Cast { object Cast {
...@@ -277,9 +277,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String ...@@ -277,9 +277,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// LongConverter // LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match { private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType => case StringType =>
buildCast[UTF8String](_, s => try s.toLong catch { val result = new LongWrapper()
case _: NumberFormatException => null buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
})
case BooleanType => case BooleanType =>
buildCast[Boolean](_, b => if (b) 1L else 0L) buildCast[Boolean](_, b => if (b) 1L else 0L)
case DateType => case DateType =>
...@@ -293,9 +292,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String ...@@ -293,9 +292,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// IntConverter // IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match { private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType => case StringType =>
buildCast[UTF8String](_, s => try s.toInt catch { val result = new IntWrapper()
case _: NumberFormatException => null buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
})
case BooleanType => case BooleanType =>
buildCast[Boolean](_, b => if (b) 1 else 0) buildCast[Boolean](_, b => if (b) 1 else 0)
case DateType => case DateType =>
...@@ -309,8 +307,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String ...@@ -309,8 +307,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// ShortConverter // ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match { private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType => case StringType =>
buildCast[UTF8String](_, s => try s.toShort catch { val result = new IntWrapper()
case _: NumberFormatException => null buildCast[UTF8String](_, s => if (s.toShort(result)) {
result.value.toShort
} else {
null
}) })
case BooleanType => case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
...@@ -325,8 +326,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String ...@@ -325,8 +326,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// ByteConverter // ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match { private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType => case StringType =>
buildCast[UTF8String](_, s => try s.toByte catch { val result = new IntWrapper()
case _: NumberFormatException => null buildCast[UTF8String](_, s => if (s.toByte(result)) {
result.value.toByte
} else {
null
}) })
case BooleanType => case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
...@@ -503,11 +507,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String ...@@ -503,11 +507,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case TimestampType => castToTimestampCode(from, ctx) case TimestampType => castToTimestampCode(from, ctx)
case CalendarIntervalType => castToIntervalCode(from) case CalendarIntervalType => castToIntervalCode(from)
case BooleanType => castToBooleanCode(from) case BooleanType => castToBooleanCode(from)
case ByteType => castToByteCode(from) case ByteType => castToByteCode(from, ctx)
case ShortType => castToShortCode(from) case ShortType => castToShortCode(from, ctx)
case IntegerType => castToIntCode(from) case IntegerType => castToIntCode(from, ctx)
case FloatType => castToFloatCode(from) case FloatType => castToFloatCode(from)
case LongType => castToLongCode(from) case LongType => castToLongCode(from, ctx)
case DoubleType => castToDoubleCode(from) case DoubleType => castToDoubleCode(from)
case array: ArrayType => case array: ArrayType =>
...@@ -734,13 +738,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String ...@@ -734,13 +738,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = $c != 0;" (c, evPrim, evNull) => s"$evPrim = $c != 0;"
} }
private[this] def castToByteCode(from: DataType): CastFunction = from match { private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType => case StringType =>
val wrapper = ctx.freshName("wrapper")
ctx.addMutableState("UTF8String.IntWrapper", wrapper,
s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) => (c, evPrim, evNull) =>
s""" s"""
try { if ($c.toByte($wrapper)) {
$evPrim = $c.toByte(); $evPrim = (byte) $wrapper.value;
} catch (java.lang.NumberFormatException e) { } else {
$evNull = true; $evNull = true;
} }
""" """
...@@ -756,13 +763,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String ...@@ -756,13 +763,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (byte) $c;" (c, evPrim, evNull) => s"$evPrim = (byte) $c;"
} }
private[this] def castToShortCode(from: DataType): CastFunction = from match { private[this] def castToShortCode(
from: DataType,
ctx: CodegenContext): CastFunction = from match {
case StringType => case StringType =>
val wrapper = ctx.freshName("wrapper")
ctx.addMutableState("UTF8String.IntWrapper", wrapper,
s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) => (c, evPrim, evNull) =>
s""" s"""
try { if ($c.toShort($wrapper)) {
$evPrim = $c.toShort(); $evPrim = (short) $wrapper.value;
} catch (java.lang.NumberFormatException e) { } else {
$evNull = true; $evNull = true;
} }
""" """
...@@ -778,13 +790,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String ...@@ -778,13 +790,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (short) $c;" (c, evPrim, evNull) => s"$evPrim = (short) $c;"
} }
private[this] def castToIntCode(from: DataType): CastFunction = from match { private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType => case StringType =>
val wrapper = ctx.freshName("wrapper")
ctx.addMutableState("UTF8String.IntWrapper", wrapper,
s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) => (c, evPrim, evNull) =>
s""" s"""
try { if ($c.toInt($wrapper)) {
$evPrim = $c.toInt(); $evPrim = $wrapper.value;
} catch (java.lang.NumberFormatException e) { } else {
$evNull = true; $evNull = true;
} }
""" """
...@@ -800,13 +815,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String ...@@ -800,13 +815,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (int) $c;" (c, evPrim, evNull) => s"$evPrim = (int) $c;"
} }
private[this] def castToLongCode(from: DataType): CastFunction = from match { private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType => case StringType =>
val wrapper = ctx.freshName("wrapper")
ctx.addMutableState("UTF8String.LongWrapper", wrapper,
s"$wrapper = new UTF8String.LongWrapper();")
(c, evPrim, evNull) => (c, evPrim, evNull) =>
s""" s"""
try { if ($c.toLong($wrapper)) {
$evPrim = $c.toLong(); $evPrim = $wrapper.value;
} catch (java.lang.NumberFormatException e) { } else {
$evNull = true; $evNull = true;
} }
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment