diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 5c1908d55576a4bd0985bbf50ba78cff18cc69e1..438215e8e6e37dd57eb1a04c1da82c536227c8ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -640,7 +640,7 @@ case class StringSplit(str: Expression, pattern: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -649,58 +649,59 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def foldable: Boolean = str.foldable && pos.foldable && len.foldable override def nullable: Boolean = str.nullable || pos.nullable || len.nullable - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") - } - if (str.dataType == BinaryType) str.dataType else StringType - } + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) override def children: Seq[Expression] = str :: pos :: len :: Nil - @inline - def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { - // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and - // negative indices for start positions. If a start index i is greater than 0, it - // refers to element i-1 in the sequence. If a start index i is less than 0, it refers - // to the -ith element before the end of the sequence. If a start index i is 0, it - // refers to the first element. - - val start = startPos match { - case pos if pos > 0 => pos - 1 - case neg if neg < 0 => length() + neg - case _ => 0 - } - - val end = sliceLen match { - case max if max == Integer.MAX_VALUE => max - case x => start + x + override def eval(input: InternalRow): Any = { + val stringEval = str.eval(input) + if (stringEval != null) { + val posEval = pos.eval(input) + if (posEval != null) { + val lenEval = len.eval(input) + if (lenEval != null) { + stringEval.asInstanceOf[UTF8String] + .substringSQL(posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int]) + } else { + null + } + } else { + null + } + } else { + null } - - (start, end) } - override def eval(input: InternalRow): Any = { - val string = str.eval(input) - val po = pos.eval(input) - val ln = len.eval(input) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val strGen = str.gen(ctx) + val posGen = pos.gen(ctx) + val lenGen = len.gen(ctx) - if ((string == null) || (po == null) || (ln == null)) { - null - } else { - val start = po.asInstanceOf[Int] - val length = ln.asInstanceOf[Int] - string match { - case ba: Array[Byte] => - val (st, end) = slicePos(start, length, () => ba.length) - ba.slice(st, end) - case s: UTF8String => - val (st, end) = slicePos(start, length, () => s.numChars()) - s.substring(st, end) + val start = ctx.freshName("start") + val end = ctx.freshName("end") + + s""" + ${strGen.code} + boolean ${ev.isNull} = ${strGen.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${posGen.code} + if (!${posGen.isNull}) { + ${lenGen.code} + if (!${lenGen.isNull}) { + ${ev.primitive} = ${strGen.primitive} + .substringSQL(${posGen.primitive}, ${lenGen.primitive}); + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } } - } + """ } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index ed354f7f877f11d67570d9fd838f90b857f037b0..946d355f1fc283e32e6949acaaedf6a978a60d12 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -165,6 +165,18 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { return fromBytes(bytes); } + public UTF8String substringSQL(int pos, int length) { + // Information regarding the pos calculation: + // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and + // negative indices for start positions. If a start index i is greater than 0, it + // refers to element i-1 in the sequence. If a start index i is less than 0, it refers + // to the -ith element before the end of the sequence. If a start index i is 0, it + // refers to the first element. + int start = (pos > 0) ? pos -1 : ((pos < 0) ? numChars() + pos : 0); + int end = (length == Integer.MAX_VALUE) ? Integer.MAX_VALUE : start + length; + return substring(start, end); + } + /** * Returns whether this contains `substring` or not. */ diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 1f5572c509bdb65509b8b4bce9058910a0eaa6d5..e2a5628ff4d9329797c71b7de6e6ab77bfe2fd7f 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -272,6 +272,25 @@ public class UTF8StringSuite { fromString("æ•°æ®ç –头").rpad(12, fromString("å™è¡Œè€…"))); } + @Test + public void substringSQL() { + UTF8String e = fromString("example"); + assertEquals(e.substringSQL(0, 2), fromString("ex")); + assertEquals(e.substringSQL(1, 2), fromString("ex")); + assertEquals(e.substringSQL(0, 7), fromString("example")); + assertEquals(e.substringSQL(1, 2), fromString("ex")); + assertEquals(e.substringSQL(0, 100), fromString("example")); + assertEquals(e.substringSQL(1, 100), fromString("example")); + assertEquals(e.substringSQL(2, 2), fromString("xa")); + assertEquals(e.substringSQL(1, 6), fromString("exampl")); + assertEquals(e.substringSQL(2, 100), fromString("xample")); + assertEquals(e.substringSQL(0, 0), fromString("")); + assertEquals(e.substringSQL(100, 4), EMPTY_UTF8); + assertEquals(e.substringSQL(0, Integer.MAX_VALUE), fromString("example")); + assertEquals(e.substringSQL(1, Integer.MAX_VALUE), fromString("example")); + assertEquals(e.substringSQL(2, Integer.MAX_VALUE), fromString("xample")); + } + @Test public void split() { assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1),