diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index a5682428b3d40d219ac87904394e79aa0cb7c5c3..5c1908d55576a4bd0985bbf50ba78cff18cc69e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -615,7 +615,7 @@ case class StringSpace(child: Expression) * Splits str around pat (pattern is a regular expression). */ case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = pattern @@ -623,9 +623,13 @@ case class StringSplit(str: Expression, pattern: Expression) override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def nullSafeEval(string: Any, regex: Any): Any = { - val splits = - string.asInstanceOf[UTF8String].toString.split(regex.asInstanceOf[UTF8String].toString, -1) - splits.toSeq.map(UTF8String.fromString) + string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1).toSeq + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (str, pattern) => + s"""${ev.primitive} = scala.collection.JavaConversions.asScalaBuffer( + java.util.Arrays.asList($str.split($pattern, -1)));""") } override def prettyName: String = "split" diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index fc63fe537d226cc046f147fdcd03abd2b27d33f7..ed354f7f877f11d67570d9fd838f90b857f037b0 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -487,6 +487,15 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { return fromBytes(result); } + public UTF8String[] split(UTF8String pattern, int limit) { + String[] splits = toString().split(pattern.toString(), limit); + UTF8String[] res = new UTF8String[splits.length]; + for (int i = 0; i < res.length; i++) { + res[i] = fromString(splits[i]); + } + return res; + } + @Override public String toString() { try { diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index d730b1d1384f568d4ec2454423d67a87c134710a..1f5572c509bdb65509b8b4bce9058910a0eaa6d5 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.types; import java.io.UnsupportedEncodingException; +import java.util.Arrays; import org.junit.Test; @@ -270,6 +271,16 @@ public class UTF8StringSuite { fromString("æ•°æ®ç –头å™è¡Œè€…å™è¡Œè€…å™è¡Œ"), fromString("æ•°æ®ç –头").rpad(12, fromString("å™è¡Œè€…"))); } + + @Test + public void split() { + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1), + new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi")})); + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), + new UTF8String[]{fromString("ab"), fromString("def,ghi")})); + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), + new UTF8String[]{fromString("ab"), fromString("def,ghi")})); + } @Test public void levenshteinDistance() {