Commit 6e1e2eba authored by Reynold Xin

[SPARK-8240][SQL] string function: concat

Author: Reynold Xin <rxin@databricks.com>

Closes #7486 from rxin/concat and squashes the following commits:

5217d6e [Reynold Xin] Removed Hive's concat test.
f5cb7a3 [Reynold Xin] Concat is never nullable.
ae4e61f [Reynold Xin] Removed extra import.
fddcbbd [Reynold Xin] Fixed NPE.
22e831c [Reynold Xin] Added missing file.
57a2352 [Reynold Xin] [SPARK-8240][SQL] string function: concat
parent 3d2134fc
Showing changed files with 421 additions and 247 deletions
@@ -152,6 +152,7 @@ object FunctionRegistry {
// string functions
expression[Ascii]("ascii"),
expression[Base64]("base64"),
expression[Concat]("concat"),
expression[Encode]("encode"),
expression[Decode]("decode"),
expression[FormatNumber]("format_number"),
...
@@ -27,6 +27,43 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines expressions for string operations.
////////////////////////////////////////////////////////////////////////////////////////////////////
/**
* An expression that concatenates multiple input strings into a single string.
* Input expressions that evaluate to null are skipped.
*
* For example, `concat("a", null, "b")` is evaluated to `"ab"`.
*
* Note that this differs from Hive, which outputs null if any input is null.
* We never output null.
*/
case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
override def dataType: DataType = StringType
override def nullable: Boolean = false
override def foldable: Boolean = children.forall(_.foldable)
override def eval(input: InternalRow): Any = {
val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
UTF8String.concat(inputs : _*)
}
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val evals = children.map(_.gen(ctx))
val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.primitive}" }.mkString(", ")
evals.map(_.code).mkString("\n") + s"""
boolean ${ev.isNull} = false;
UTF8String ${ev.primitive} = UTF8String.concat($inputs);
"""
}
}
trait StringRegexExpression extends ImplicitCastInputTypes {
self: BinaryExpression =>
...
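For orientation, here is a minimal, illustrative sketch (not part of the patch) of how the Concat expression above behaves when evaluated directly. It assumes the Catalyst classes from this diff are on the classpath; the expected results follow from the null-skipping semantics described in the Scaladoc:

import org.apache.spark.sql.catalyst.expressions.{Concat, Literal}
import org.apache.spark.sql.types.StringType

// Literals ignore the input row, so passing null to eval() is enough here.
val expr = Concat(Seq(Literal("a"), Literal.create(null, StringType), Literal("b")))
assert(expr.eval(null).toString == "ab")   // the null child is skipped

// With no children the result is the empty string, never null.
assert(Concat(Nil).eval(null).toString == "")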
@@ -22,7 +22,29 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("concat") {
def testConcat(inputs: String*): Unit = {
val expected = inputs.filter(_ != null).mkString
checkEvaluation(Concat(inputs.map(Literal.create(_, StringType))), expected, EmptyRow)
}
testConcat()
testConcat(null)
testConcat("")
testConcat("ab")
testConcat("a", "b")
testConcat("a", "b", "C")
testConcat("a", null, "C")
testConcat("a", null, null)
testConcat(null, null, null)
// scalastyle:off
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
testConcat("数据", null, "砖头")
// scalastyle:on
}
test("StringComparison") { test("StringComparison") {
val row = create_row("abc", null) val row = create_row("abc", null)
......
@@ -1710,6 +1710,28 @@ object functions {
// String functions
//////////////////////////////////////////////////////////////////////////////////////////////
/**
* Concatenates input strings together into a single string.
*
* @group string_funcs
* @since 1.5.0
*/
@scala.annotation.varargs
def concat(exprs: Column*): Column = Concat(exprs.map(_.expr))
/**
* Concatenates input strings together into a single string.
*
* This is the variant of concat that takes in the column names.
*
* @group string_funcs
* @since 1.5.0
*/
@scala.annotation.varargs
def concat(columnName: String, columnNames: String*): Column = {
concat((columnName +: columnNames).map(Column.apply): _*)
}
/**
* Computes the length of a given string / binary value.
*
...
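As a quick usage sketch for the two public variants above (illustrative only, not part of the patch; it assumes a SQLContext is in scope with its implicits imported, and mirrors the DataFrame test added elsewhere in this commit):

import org.apache.spark.sql.functions._
// e.g. val ctx = org.apache.spark.sql.test.TestSQLContext; import ctx.implicits._

val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
df.select(concat($"a", $"b", $"c"))    // Column-based variant; yields "ab"
df.select(concat("a", "b", "c"))       // column-name variant; yields the same "ab"
df.selectExpr("concat(a, b, c)")       // SQL form registered in FunctionRegistry; yields "ab"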
@@ -208,169 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(2743272264L, 2180413220L))
}
test("Levenshtein distance") {
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
}
test("string ascii function") {
val df = Seq(("abc", "")).toDF("a", "b")
checkAnswer(
df.select(ascii($"a"), ascii("b")),
Row(97, 0))
checkAnswer(
df.selectExpr("ascii(a)", "ascii(b)"),
Row(97, 0))
}
test("string base64/unbase64 function") {
val bytes = Array[Byte](1, 2, 3, 4)
val df = Seq((bytes, "AQIDBA==")).toDF("a", "b")
checkAnswer(
df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")),
Row("AQIDBA==", "AQIDBA==", bytes, bytes))
checkAnswer(
df.selectExpr("base64(a)", "unbase64(b)"),
Row("AQIDBA==", bytes))
}
test("string encode/decode function") {
val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
// scalastyle:off
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
checkAnswer(
df.select(
encode($"a", "utf-8"),
encode("a", "utf-8"),
decode($"c", "utf-8"),
decode("c", "utf-8")),
Row(bytes, bytes, "大千世界", "大千世界"))
checkAnswer(
df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"),
Row(bytes, "大千世界"))
// scalastyle:on
}
test("string trim functions") {
val df = Seq((" example ", "")).toDF("a", "b")
checkAnswer(
df.select(ltrim($"a"), rtrim($"a"), trim($"a")),
Row("example ", " example", "example"))
checkAnswer(
df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"),
Row("example ", " example", "example"))
}
test("string formatString function") {
val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")
checkAnswer(
df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
Row("aa123cc", "aa123cc"))
checkAnswer(
df.selectExpr("printf(a, b, c)"),
Row("aa123cc"))
}
test("string instr function") {
val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c")
checkAnswer(
df.select(instr($"a", $"b"), instr("a", "b")),
Row(1, 1))
checkAnswer(
df.selectExpr("instr(a, b)"),
Row(1))
}
test("string locate function") {
val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
checkAnswer(
df.select(
locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1),
locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")),
Row(1, 1, 2, 2, 2, 2))
checkAnswer(
df.selectExpr("locate(b, a)", "locate(b, a, d)"),
Row(1, 2))
}
test("string padding functions") {
val df = Seq(("hi", 5, "??")).toDF("a", "b", "c")
checkAnswer(
df.select(
lpad($"a", $"b", $"c"), rpad("a", "b", "c"),
lpad($"a", 1, $"c"), rpad("a", 1, "c")),
Row("???hi", "hi???", "h", "h"))
checkAnswer(
df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"),
Row("???hi", "hi???", "h", "h"))
}
test("string repeat function") {
val df = Seq(("hi", 2)).toDF("a", "b")
checkAnswer(
df.select(
repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")),
Row("hihi", "hihi", "hihi", "hihi"))
checkAnswer(
df.selectExpr("repeat(a, 2)", "repeat(a, b)"),
Row("hihi", "hihi"))
}
test("string reverse function") {
val df = Seq(("hi", "hhhi")).toDF("a", "b")
checkAnswer(
df.select(reverse($"a"), reverse("b")),
Row("ih", "ihhh"))
checkAnswer(
df.selectExpr("reverse(b)"),
Row("ihhh"))
}
test("string space function") {
val df = Seq((2, 3)).toDF("a", "b")
checkAnswer(
df.select(space($"a"), space("b")),
Row(" ", " "))
checkAnswer(
df.selectExpr("space(b)"),
Row(" "))
}
test("string split function") {
val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b")
checkAnswer(
df.select(
split($"a", "[1-9]+"),
split("a", "[1-9]+")),
Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc")))
checkAnswer(
df.selectExpr("split(a, '[1-9]+')"),
Row(Seq("aa", "bb", "cc")))
}
test("conditional function: least") { test("conditional function: least") {
checkAnswer( checkAnswer(
testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1), testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1),
...@@ -430,83 +267,4 @@ class DataFrameFunctionsSuite extends QueryTest { ...@@ -430,83 +267,4 @@ class DataFrameFunctionsSuite extends QueryTest {
) )
} }
test("string / binary length function") {
val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c")
checkAnswer(
df.select(length($"a"), length("a"), length($"b"), length("b")),
Row(3, 3, 4, 4))
checkAnswer(
df.selectExpr("length(a)", "length(b)"),
Row(3, 4))
intercept[AnalysisException] {
checkAnswer(
df.selectExpr("length(c)"), // int type of the argument is unacceptable
Row("5.0000"))
}
}
test("number format function") {
val tuple =
("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
val df =
Seq(tuple)
.toDF(
"a", // string "aa"
"b", // byte 1
"c", // short 2
"d", // float 3.13223f
"e", // integer 4
"f", // long 5L
"g", // double 6.48173d
"h") // decimal 7.128381
checkAnswer(
df.select(
format_number($"f", 4),
format_number("f", 4)),
Row("5.0000", "5.0000"))
checkAnswer(
df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer
Row("1.0000"))
checkAnswer(
df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer
Row("2.0000"))
checkAnswer(
df.selectExpr("format_number(d, e)"), // convert the 1st argument to double
Row("3.1322"))
checkAnswer(
df.selectExpr("format_number(e, e)"), // not convert anything
Row("4.0000"))
checkAnswer(
df.selectExpr("format_number(f, e)"), // not convert anything
Row("5.0000"))
checkAnswer(
df.selectExpr("format_number(g, e)"), // not convert anything
Row("6.4817"))
checkAnswer(
df.selectExpr("format_number(h, e)"), // not convert anything
Row("7.1284"))
intercept[AnalysisException] {
checkAnswer(
df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable
Row("5.0000"))
}
intercept[AnalysisException] {
checkAnswer(
df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable
Row("5.0000"))
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.Decimal
class StringFunctionsSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
test("string concat") {
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
checkAnswer(
df.select(concat($"a", $"b", $"c")),
Row("ab"))
checkAnswer(
df.selectExpr("concat(a, b, c)"),
Row("ab"))
}
test("string Levenshtein distance") {
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
}
test("string ascii function") {
val df = Seq(("abc", "")).toDF("a", "b")
checkAnswer(
df.select(ascii($"a"), ascii("b")),
Row(97, 0))
checkAnswer(
df.selectExpr("ascii(a)", "ascii(b)"),
Row(97, 0))
}
test("string base64/unbase64 function") {
val bytes = Array[Byte](1, 2, 3, 4)
val df = Seq((bytes, "AQIDBA==")).toDF("a", "b")
checkAnswer(
df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")),
Row("AQIDBA==", "AQIDBA==", bytes, bytes))
checkAnswer(
df.selectExpr("base64(a)", "unbase64(b)"),
Row("AQIDBA==", bytes))
}
test("string encode/decode function") {
val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
// scalastyle:off
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
checkAnswer(
df.select(
encode($"a", "utf-8"),
encode("a", "utf-8"),
decode($"c", "utf-8"),
decode("c", "utf-8")),
Row(bytes, bytes, "大千世界", "大千世界"))
checkAnswer(
df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"),
Row(bytes, "大千世界"))
// scalastyle:on
}
test("string trim functions") {
val df = Seq((" example ", "")).toDF("a", "b")
checkAnswer(
df.select(ltrim($"a"), rtrim($"a"), trim($"a")),
Row("example ", " example", "example"))
checkAnswer(
df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"),
Row("example ", " example", "example"))
}
test("string formatString function") {
val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")
checkAnswer(
df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
Row("aa123cc", "aa123cc"))
checkAnswer(
df.selectExpr("printf(a, b, c)"),
Row("aa123cc"))
}
test("string instr function") {
val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c")
checkAnswer(
df.select(instr($"a", $"b"), instr("a", "b")),
Row(1, 1))
checkAnswer(
df.selectExpr("instr(a, b)"),
Row(1))
}
test("string locate function") {
val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
checkAnswer(
df.select(
locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1),
locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")),
Row(1, 1, 2, 2, 2, 2))
checkAnswer(
df.selectExpr("locate(b, a)", "locate(b, a, d)"),
Row(1, 2))
}
test("string padding functions") {
val df = Seq(("hi", 5, "??")).toDF("a", "b", "c")
checkAnswer(
df.select(
lpad($"a", $"b", $"c"), rpad("a", "b", "c"),
lpad($"a", 1, $"c"), rpad("a", 1, "c")),
Row("???hi", "hi???", "h", "h"))
checkAnswer(
df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"),
Row("???hi", "hi???", "h", "h"))
}
test("string repeat function") {
val df = Seq(("hi", 2)).toDF("a", "b")
checkAnswer(
df.select(
repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")),
Row("hihi", "hihi", "hihi", "hihi"))
checkAnswer(
df.selectExpr("repeat(a, 2)", "repeat(a, b)"),
Row("hihi", "hihi"))
}
test("string reverse function") {
val df = Seq(("hi", "hhhi")).toDF("a", "b")
checkAnswer(
df.select(reverse($"a"), reverse("b")),
Row("ih", "ihhh"))
checkAnswer(
df.selectExpr("reverse(b)"),
Row("ihhh"))
}
test("string space function") {
val df = Seq((2, 3)).toDF("a", "b")
checkAnswer(
df.select(space($"a"), space("b")),
Row(" ", " "))
checkAnswer(
df.selectExpr("space(b)"),
Row(" "))
}
test("string split function") {
val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b")
checkAnswer(
df.select(
split($"a", "[1-9]+"),
split("a", "[1-9]+")),
Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc")))
checkAnswer(
df.selectExpr("split(a, '[1-9]+')"),
Row(Seq("aa", "bb", "cc")))
}
test("string / binary length function") {
val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c")
checkAnswer(
df.select(length($"a"), length("a"), length($"b"), length("b")),
Row(3, 3, 4, 4))
checkAnswer(
df.selectExpr("length(a)", "length(b)"),
Row(3, 4))
intercept[AnalysisException] {
checkAnswer(
df.selectExpr("length(c)"), // int type of the argument is unacceptable
Row("5.0000"))
}
}
test("number format function") {
val tuple =
("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
val df =
Seq(tuple)
.toDF(
"a", // string "aa"
"b", // byte 1
"c", // short 2
"d", // float 3.13223f
"e", // integer 4
"f", // long 5L
"g", // double 6.48173d
"h") // decimal 7.128381
checkAnswer(
df.select(
format_number($"f", 4),
format_number("f", 4)),
Row("5.0000", "5.0000"))
checkAnswer(
df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer
Row("1.0000"))
checkAnswer(
df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer
Row("2.0000"))
checkAnswer(
df.selectExpr("format_number(d, e)"), // convert the 1st argument to double
Row("3.1322"))
checkAnswer(
df.selectExpr("format_number(e, e)"), // not convert anything
Row("4.0000"))
checkAnswer(
df.selectExpr("format_number(f, e)"), // not convert anything
Row("5.0000"))
checkAnswer(
df.selectExpr("format_number(g, e)"), // not convert anything
Row("6.4817"))
checkAnswer(
df.selectExpr("format_number(h, e)"), // not convert anything
Row("7.1284"))
intercept[AnalysisException] {
checkAnswer(
df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable
Row("5.0000"))
}
intercept[AnalysisException] {
checkAnswer(
df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable
Row("5.0000"))
}
}
}
@@ -256,6 +256,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"timestamp_2",
"timestamp_udf",
// Hive outputs NULL if any concat input has null. We never output null for concat.
"udf_concat",
// Unlike Hive, we do support log base in (0, 1.0], therefore disable this
"udf7"
)
@@ -846,7 +849,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_case",
"udf_ceil",
"udf_ceiling",
"udf_concat",
"udf_concat_insert1", "udf_concat_insert1",
"udf_concat_insert2", "udf_concat_insert2",
"udf_concat_ws", "udf_concat_ws",
......
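To make the udf_concat exclusion above concrete, a small illustrative sketch (assuming the same DataFrame df with columns a = "a", b = "b", c = null used in the StringFunctionsSuite test earlier in this diff):

// Hive's concat returns NULL whenever any argument is NULL, so the Hive
// golden answer for udf_concat no longer matches Spark's output.
// Spark SQL's concat, added in this patch, skips null inputs instead:
df.selectExpr("concat(a, b, c)")   // Spark SQL yields "ab"; Hive would yield NULL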
@@ -21,6 +21,7 @@ import javax.annotation.Nonnull;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import static org.apache.spark.unsafe.PlatformDependent.*;
@@ -322,7 +323,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
i += numBytesForFirstByte(getByte(i));
c += 1;
} while (i < numBytes);
return -1;
}
@@ -395,6 +396,39 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
}
/**
* Concatenates input strings together into a single string. A null input is skipped.
* For example, concat("a", null, "c") would yield "ac".
*/
public static UTF8String concat(UTF8String... inputs) {
if (inputs == null) {
return fromBytes(new byte[0]);
}
// Compute the total length of the result.
int totalLength = 0;
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
totalLength += inputs[i].numBytes;
}
}
// Allocate a new byte array, and copy the inputs one by one into it.
final byte[] result = new byte[totalLength];
int offset = 0;
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
int len = inputs[i].numBytes;
PlatformDependent.copyMemory(
inputs[i].base, inputs[i].offset,
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
len);
offset += len;
}
}
return fromBytes(result);
}
@Override
public String toString() {
try {
@@ -413,7 +447,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
@Override
public int compareTo(@Nonnull final UTF8String other) {
int len = Math.min(numBytes, other.numBytes);
// TODO: compare 8 bytes as unsigned long
for (int i = 0; i < len; i ++) {
@@ -434,7 +468,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
public boolean equals(final Object other) {
if (other instanceof UTF8String) {
UTF8String o = (UTF8String) other;
if (numBytes != o.numBytes) {
return false;
}
return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes);
...
@@ -86,6 +86,20 @@ public class UTF8StringSuite {
testUpperandLower("大千世界 数据砖头", "大千世界 数据砖头");
}
@Test
public void concatTest() {
assertEquals(concat(), fromString(""));
assertEquals(concat(null), fromString(""));
assertEquals(concat(fromString("")), fromString(""));
assertEquals(concat(fromString("ab")), fromString("ab"));
assertEquals(concat(fromString("a"), fromString("b")), fromString("ab"));
assertEquals(concat(fromString("a"), fromString("b"), fromString("c")), fromString("abc"));
assertEquals(concat(fromString("a"), null, fromString("c")), fromString("ac"));
assertEquals(concat(fromString("a"), null, null), fromString("a"));
assertEquals(concat(null, null, null), fromString(""));
assertEquals(concat(fromString("数据"), fromString("砖头")), fromString("数据砖头"));
}
@Test
public void contains() {
assertTrue(fromString("").contains(fromString("")));
...