Skip to content
Snippets Groups Projects
Commit 6b899438 authored by Wenchen Fan's avatar Wenchen Fan Committed by Reynold Xin
Browse files

[SPARK-8944][SQL] Support casting between IntervalType and StringType

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7355 from cloud-fan/fromString and squashes the following commits:

3bbb9d6 [Wenchen Fan] fix code gen
7dab957 [Wenchen Fan] naming fix
0fbbe19 [Wenchen Fan] address comments
ac1f3d1 [Wenchen Fan] Support casting between IntervalType and StringType
parent 92540d22
No related branches found
No related tags found
No related merge requests found
...@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult ...@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.unsafe.types.{Interval, UTF8String}
object Cast { object Cast {
...@@ -55,6 +55,9 @@ object Cast { ...@@ -55,6 +55,9 @@ object Cast {
case (_, DateType) => true case (_, DateType) => true
case (StringType, IntervalType) => true
case (IntervalType, StringType) => true
case (StringType, _: NumericType) => true case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true case (BooleanType, _: NumericType) => true
case (DateType, _: NumericType) => true case (DateType, _: NumericType) => true
...@@ -232,6 +235,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w ...@@ -232,6 +235,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case _ => _ => null case _ => _ => null
} }
// IntervalConverter
private[this] def castToInterval(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => Interval.fromString(s.toString))
case _ => _ => null
}
// LongConverter // LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match { private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType => case StringType =>
...@@ -405,6 +415,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w ...@@ -405,6 +415,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case DateType => castToDate(from) case DateType => castToDate(from)
case decimal: DecimalType => castToDecimal(from, decimal) case decimal: DecimalType => castToDecimal(from, decimal)
case TimestampType => castToTimestamp(from) case TimestampType => castToTimestamp(from)
case IntervalType => castToInterval(from)
case BooleanType => castToBoolean(from) case BooleanType => castToBoolean(from)
case ByteType => castToByte(from) case ByteType => castToByte(from)
case ShortType => castToShort(from) case ShortType => castToShort(from)
...@@ -442,6 +453,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w ...@@ -442,6 +453,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (_, StringType) => case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))") defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))")
case (StringType, IntervalType) =>
defineCodeGen(ctx, ev, c =>
s"org.apache.spark.unsafe.types.Interval.fromString($c.toString())")
// fallback for DecimalType, this must be before other numeric types // fallback for DecimalType, this must be before other numeric types
case (_, dt: DecimalType) => case (_, dt: DecimalType) =>
super.genCode(ctx, ev) super.genCode(ctx, ev)
......
...@@ -563,4 +563,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { ...@@ -563,4 +563,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
InternalRow(0L))) InternalRow(0L)))
} }
test("case between string and interval") {
import org.apache.spark.unsafe.types.Interval
checkEvaluation(Cast(Literal("interval -3 month 7 hours"), IntervalType),
new Interval(-3, 7 * Interval.MICROS_PER_HOUR))
checkEvaluation(Cast(Literal.create(
new Interval(15, -3 * Interval.MICROS_PER_DAY), IntervalType), StringType),
"interval 1 years 3 months -3 days")
}
} }
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
package org.apache.spark.unsafe.types; package org.apache.spark.unsafe.types;
import java.io.Serializable; import java.io.Serializable;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/** /**
* The internal representation of interval type. * The internal representation of interval type.
...@@ -30,6 +32,52 @@ public final class Interval implements Serializable { ...@@ -30,6 +32,52 @@ public final class Interval implements Serializable {
public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24; public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24;
public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7; public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7;
/**
* A function to generate regex which matches interval string's unit part like "3 years".
*
* First, we can leave out some units in interval string, and we only care about the value of
* unit, so here we use non-capturing group to wrap the actual regex.
* At the beginning of the actual regex, we should match spaces before the unit part.
* Next is the number part, starts with an optional "-" to represent negative value. We use
* capturing group to wrap this part as we need the value later.
* Finally is the unit name, ends with an optional "s".
*/
private static String unitRegex(String unit) {
return "(?:\\s+(-?\\d+)\\s+" + unit + "s?)?";
}
private static Pattern p = Pattern.compile("interval" + unitRegex("year") + unitRegex("month") +
unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") +
unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond"));
private static long toLong(String s) {
if (s == null) {
return 0;
} else {
return Long.valueOf(s);
}
}
public static Interval fromString(String s) {
if (s == null) {
return null;
}
Matcher m = p.matcher(s);
if (!m.matches() || s.equals("interval")) {
return null;
} else {
long months = toLong(m.group(1)) * 12 + toLong(m.group(2));
long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK;
microseconds += toLong(m.group(4)) * MICROS_PER_DAY;
microseconds += toLong(m.group(5)) * MICROS_PER_HOUR;
microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE;
microseconds += toLong(m.group(7)) * MICROS_PER_SECOND;
microseconds += toLong(m.group(8)) * MICROS_PER_MILLI;
microseconds += toLong(m.group(9));
return new Interval((int) months, microseconds);
}
}
public final int months; public final int months;
public final long microseconds; public final long microseconds;
......
...@@ -56,4 +56,50 @@ public class IntervalSuite { ...@@ -56,4 +56,50 @@ public class IntervalSuite {
i = new Interval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); i = new Interval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123);
assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds"); assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds");
} }
@Test
public void fromStringTest() {
testSingleUnit("year", 3, 36, 0);
testSingleUnit("month", 3, 3, 0);
testSingleUnit("week", 3, 0, 3 * MICROS_PER_WEEK);
testSingleUnit("day", 3, 0, 3 * MICROS_PER_DAY);
testSingleUnit("hour", 3, 0, 3 * MICROS_PER_HOUR);
testSingleUnit("minute", 3, 0, 3 * MICROS_PER_MINUTE);
testSingleUnit("second", 3, 0, 3 * MICROS_PER_SECOND);
testSingleUnit("millisecond", 3, 0, 3 * MICROS_PER_MILLI);
testSingleUnit("microsecond", 3, 0, 3);
String input;
input = "interval -5 years 23 month";
Interval result = new Interval(-5 * 12 + 23, 0);
assertEquals(Interval.fromString(input), result);
// Error cases
input = "interval 3month 1 hour";
assertEquals(Interval.fromString(input), null);
input = "interval 3 moth 1 hour";
assertEquals(Interval.fromString(input), null);
input = "interval";
assertEquals(Interval.fromString(input), null);
input = "int";
assertEquals(Interval.fromString(input), null);
input = "";
assertEquals(Interval.fromString(input), null);
input = null;
assertEquals(Interval.fromString(input), null);
}
private void testSingleUnit(String unit, int number, int months, long microseconds) {
String input1 = "interval " + number + " " + unit;
String input2 = "interval " + number + " " + unit + "s";
Interval result = new Interval(months, microseconds);
assertEquals(Interval.fromString(input1), result);
assertEquals(Interval.fromString(input2), result);
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment