Skip to content
Snippets Groups Projects
Commit 560b355c authored by Tarek Auel's avatar Tarek Auel Committed by Davies Liu
Browse files

[SPARK-9157] [SQL] codegen substring

https://issues.apache.org/jira/browse/SPARK-9157

Author: Tarek Auel <tarek.auel@googlemail.com>

Closes #7534 from tarekauel/SPARK-9157 and squashes the following commits:

e65e3e9 [Tarek Auel] [SPARK-9157] indent fix
44e89f8 [Tarek Auel] [SPARK-9157] use EMPTY_UTF8
37d54c4 [Tarek Auel] Merge branch 'master' into SPARK-9157
60732ea [Tarek Auel] [SPARK-9157] created substringSQL in UTF8String
18c3576 [Tarek Auel] [SPARK-9157][SQL] remove slice pos
1a2e611 [Tarek Auel] [SPARK-9157][SQL] codegen substring
parent c032b0bf
No related branches found
No related tags found
No related merge requests found
......@@ -640,7 +640,7 @@ case class StringSplit(str: Expression, pattern: Expression)
* Defined for String and Binary types.
*/
case class Substring(str: Expression, pos: Expression, len: Expression)
extends Expression with ImplicitCastInputTypes with CodegenFallback {
extends Expression with ImplicitCastInputTypes {
def this(str: Expression, pos: Expression) = {
this(str, pos, Literal(Integer.MAX_VALUE))
......@@ -649,58 +649,59 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
override def foldable: Boolean = str.foldable && pos.foldable && len.foldable
override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved")
}
if (str.dataType == BinaryType) str.dataType else StringType
}
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
override def children: Seq[Expression] = str :: pos :: len :: Nil
@inline
def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = {
// Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
// negative indices for start positions. If a start index i is greater than 0, it
// refers to element i-1 in the sequence. If a start index i is less than 0, it refers
// to the -ith element before the end of the sequence. If a start index i is 0, it
// refers to the first element.
val start = startPos match {
case pos if pos > 0 => pos - 1
case neg if neg < 0 => length() + neg
case _ => 0
}
val end = sliceLen match {
case max if max == Integer.MAX_VALUE => max
case x => start + x
override def eval(input: InternalRow): Any = {
val stringEval = str.eval(input)
if (stringEval != null) {
val posEval = pos.eval(input)
if (posEval != null) {
val lenEval = len.eval(input)
if (lenEval != null) {
stringEval.asInstanceOf[UTF8String]
.substringSQL(posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int])
} else {
null
}
} else {
null
}
} else {
null
}
(start, end)
}
override def eval(input: InternalRow): Any = {
val string = str.eval(input)
val po = pos.eval(input)
val ln = len.eval(input)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val strGen = str.gen(ctx)
val posGen = pos.gen(ctx)
val lenGen = len.gen(ctx)
if ((string == null) || (po == null) || (ln == null)) {
null
} else {
val start = po.asInstanceOf[Int]
val length = ln.asInstanceOf[Int]
string match {
case ba: Array[Byte] =>
val (st, end) = slicePos(start, length, () => ba.length)
ba.slice(st, end)
case s: UTF8String =>
val (st, end) = slicePos(start, length, () => s.numChars())
s.substring(st, end)
val start = ctx.freshName("start")
val end = ctx.freshName("end")
s"""
${strGen.code}
boolean ${ev.isNull} = ${strGen.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${posGen.code}
if (!${posGen.isNull}) {
${lenGen.code}
if (!${lenGen.isNull}) {
${ev.primitive} = ${strGen.primitive}
.substringSQL(${posGen.primitive}, ${lenGen.primitive});
} else {
${ev.isNull} = true;
}
} else {
${ev.isNull} = true;
}
}
}
"""
}
}
......
......@@ -165,6 +165,18 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
return fromBytes(bytes);
}
public UTF8String substringSQL(int pos, int length) {
// Information regarding the pos calculation:
// Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
// negative indices for start positions. If a start index i is greater than 0, it
// refers to element i-1 in the sequence. If a start index i is less than 0, it refers
// to the -ith element before the end of the sequence. If a start index i is 0, it
// refers to the first element.
int start = (pos > 0) ? pos -1 : ((pos < 0) ? numChars() + pos : 0);
int end = (length == Integer.MAX_VALUE) ? Integer.MAX_VALUE : start + length;
return substring(start, end);
}
/**
* Returns whether this contains `substring` or not.
*/
......
......@@ -272,6 +272,25 @@ public class UTF8StringSuite {
fromString("数据砖头").rpad(12, fromString("孙行者")));
}
@Test
public void substringSQL() {
UTF8String e = fromString("example");
assertEquals(e.substringSQL(0, 2), fromString("ex"));
assertEquals(e.substringSQL(1, 2), fromString("ex"));
assertEquals(e.substringSQL(0, 7), fromString("example"));
assertEquals(e.substringSQL(1, 2), fromString("ex"));
assertEquals(e.substringSQL(0, 100), fromString("example"));
assertEquals(e.substringSQL(1, 100), fromString("example"));
assertEquals(e.substringSQL(2, 2), fromString("xa"));
assertEquals(e.substringSQL(1, 6), fromString("exampl"));
assertEquals(e.substringSQL(2, 100), fromString("xample"));
assertEquals(e.substringSQL(0, 0), fromString(""));
assertEquals(e.substringSQL(100, 4), EMPTY_UTF8);
assertEquals(e.substringSQL(0, Integer.MAX_VALUE), fromString("example"));
assertEquals(e.substringSQL(1, Integer.MAX_VALUE), fromString("example"));
assertEquals(e.substringSQL(2, Integer.MAX_VALUE), fromString("xample"));
}
@Test
public void split() {
assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment