diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 6608036f01318ce700d0aff265e9dd58536253be..e42be85367aeb8821950e37d800c363cc9ebf881 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -593,17 +593,19 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 * Returns a n spaces string. */ case class StringSpace(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { + extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(IntegerType) override def nullSafeEval(s: Any): Any = { - val length = s.asInstanceOf[Integer] + val length = s.asInstanceOf[Int] + UTF8String.blankString(if (length < 0) 0 else length) + } - val spaces = new Array[Byte](if (length < 0) 0 else length) - java.util.Arrays.fill(spaces, ' '.asInstanceOf[Byte]) - UTF8String.fromBytes(spaces) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (length) => + s"""${ev.primitive} = UTF8String.blankString(($length < 0) ? 0 : $length);""") } override def prettyName: String = "space" diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 3eecd657e6ef9559f7e1d289c546a582a6395e28..819639f300177f1c7c0bddc1115453587941dab1 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -20,6 +20,7 @@ package org.apache.spark.unsafe.types; import javax.annotation.Nonnull; import java.io.Serializable; import java.io.UnsupportedEncodingException; +import java.util.Arrays; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -77,6 +78,15 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { } } + /** + * Creates an UTF8String that contains `length` spaces. + */ + public static UTF8String blankString(int length) { + byte[] spaces = new byte[length]; + Arrays.fill(spaces, (byte) ' '); + return fromBytes(spaces); + } + protected UTF8String(Object base, long offset, int size) { this.base = base; this.offset = offset; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 7d0c49e2fb84c22b450e7ed4351717cef227cf5a..6a21c27461163a66d57f3b1a5208ae9ea75b191f 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -286,4 +286,12 @@ public class UTF8StringSuite { assertEquals( UTF8String.fromString("世界åƒä¸–").levenshteinDistance(UTF8String.fromString("åƒa世b")),4); } + + @Test + public void createBlankString() { + assertEquals(fromString(" "), blankString(1)); + assertEquals(fromString(" "), blankString(2)); + assertEquals(fromString(" "), blankString(3)); + assertEquals(fromString(""), blankString(0)); + } }