From 00cd92f32f17ca57d47aa2dcc716eb707aaee799 Mon Sep 17 00:00:00 2001
From: HuJiayin <jiayin.hu@intel.com>
Date: Sat, 1 Aug 2015 21:44:57 -0700
Subject: [PATCH] [SPARK-8269] [SQL] string function: initcap

This PR is based on #7208 , thanks to HuJiayin

Closes #7208

Author: HuJiayin <jiayin.hu@intel.com>
Author: Davies Liu <davies@databricks.com>

Closes #7850 from davies/initcap and squashes the following commits:

54472e9 [Davies Liu] fix python test
17ffe51 [Davies Liu] Merge branch 'master' of github.com:apache/spark into initcap
ca46390 [Davies Liu] Merge branch 'master' of github.com:apache/spark into initcap
3a906e4 [Davies Liu] implement title case in UTF8String
8b2506a [HuJiayin] Update functions.py
2cd43e5 [HuJiayin] fix python style check
b616c0e [HuJiayin] add python api
1f5a0ef [HuJiayin] add codegen
7e0c604 [HuJiayin] Merge branch 'master' of https://github.com/apache/spark into initcap
6a0b958 [HuJiayin] add column
c79482d [HuJiayin] support soundex
7ce416b [HuJiayin] support initcap rebase code
---
 python/pyspark/sql/functions.py               | 12 +++
 .../catalyst/analysis/FunctionRegistry.scala  |  1 +
 .../expressions/stringOperations.scala        | 17 ++++
 .../expressions/StringExpressionsSuite.scala  | 12 +++
 .../org/apache/spark/sql/functions.scala      |  9 ++
 .../spark/sql/StringFunctionsSuite.scala      |  9 ++
 .../apache/spark/unsafe/types/UTF8String.java | 88 +++++++++++++++++++
 .../spark/unsafe/types/UTF8StringSuite.java   |  8 ++
 8 files changed, 156 insertions(+)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 96975f54ff..a73ecc7d93 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -958,6 +958,18 @@ def substring_index(str, delim, count):
     return Column(sc._jvm.functions.substring_index(_to_java_column(str), delim, count))
 
 
+@ignore_unicode_prefix
+@since(1.5)
+def initcap(col):
+    """Translate the first letter of each word to upper case in the sentence.
+
+    >>> sqlContext.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect()
+    [Row(v=u'Ab Cd')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.initcap(_to_java_column(col)))
+
+
 @since(1.5)
 def size(col):
     """
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 6e144518bb..8fafd7778a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -178,6 +178,7 @@ object FunctionRegistry {
     expression[Encode]("encode"),
     expression[Decode]("decode"),
     expression[FormatNumber]("format_number"),
+    expression[InitCap]("initcap"),
     expression[Lower]("lcase"),
     expression[Lower]("lower"),
     expression[Length]("length"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 4d78c55497..80c64e5689 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -597,6 +597,23 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
   override def prettyName: String = "format_string"
 }
 
+/**
+ * Returns string, with the first letter of each word in uppercase.
+ * Words are delimited by whitespace.
+ */
+case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+  override def inputTypes: Seq[DataType] = Seq(StringType)
+  override def dataType: DataType = StringType
+
+  override def nullSafeEval(string: Any): Any = {
+    string.asInstanceOf[UTF8String].toTitleCase
+  }
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    defineCodeGen(ctx, ev, str => s"$str.toTitleCase()")
+  }
+}
+
 /**
  * Returns the string which repeat the given string value n times.
  */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 89c1e33420..906be701be 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -377,6 +377,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null))
   }
 
+  test("initcap unit test") {
+    checkEvaluation(InitCap(Literal.create(null, StringType)), null)
+    checkEvaluation(InitCap(Literal("a b")), "A B")
+    checkEvaluation(InitCap(Literal(" a")), " A")
+    checkEvaluation(InitCap(Literal("the test")), "The Test")
+    // scalastyle:off
+    // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+    checkEvaluation(InitCap(Literal("世界")), "世界")
+    // scalastyle:on
+  }
+
+
   test("Levenshtein distance") {
     checkEvaluation(Levenshtein(Literal.create(null, StringType), Literal("")), null)
     checkEvaluation(Levenshtein(Literal(""), Literal.create(null, StringType)), null)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index babfe21879..818aa109f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1786,6 +1786,15 @@ object functions {
     FormatString((lit(format) +: arguments).map(_.expr): _*)
   }
 
+  /**
+   * Returns string, with the first letter of each word in uppercase.
+   * Words are delimited by whitespace.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  def initcap(e: Column): Column = InitCap(e.expr)
+
   /**
    * Locate the position of the first occurrence of substr column in the given string.
    * Returns null if either of the arguments are null.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index f40233db0a..1c1be0c3cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -315,6 +315,15 @@ class StringFunctionsSuite extends QueryTest {
     }
   }
 
+  test("initcap function") {
+    val df = Seq(("ab", "a B")).toDF("l", "r")
+    checkAnswer(
+      df.select(initcap($"l"), initcap($"r")), Row("Ab", "A B"))
+
+    checkAnswer(
+      df.selectExpr("InitCap(l)", "InitCap(r)"), Row("Ab", "A B"))
+  }
+
   test("number format function") {
     val tuple =
       ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 208503d2fd..213dc761bb 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -279,6 +279,29 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
    * Returns the upper case of this string
    */
   public UTF8String toUpperCase() {
+    if (numBytes == 0) {
+      return EMPTY_UTF8;
+    }
+
+    byte[] bytes = new byte[numBytes];
+    bytes[0] = (byte) Character.toTitleCase(getByte(0));
+    for (int i = 0; i < numBytes; i++) {
+      byte b = getByte(i);
+      if (numBytesForFirstByte(b) != 1) {
+        // fallback
+        return toUpperCaseSlow();
+      }
+      int upper = Character.toUpperCase((int) b);
+      if (upper > 127) {
+        // fallback
+        return toUpperCaseSlow();
+      }
+      bytes[i] = (byte) upper;
+    }
+    return fromBytes(bytes);
+  }
+
+  private UTF8String toUpperCaseSlow() {
     return fromString(toString().toUpperCase());
   }
 
@@ -286,9 +309,74 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
    * Returns the lower case of this string
    */
   public UTF8String toLowerCase() {
+    if (numBytes == 0) {
+      return EMPTY_UTF8;
+    }
+
+    byte[] bytes = new byte[numBytes];
+    bytes[0] = (byte) Character.toTitleCase(getByte(0));
+    for (int i = 0; i < numBytes; i++) {
+      byte b = getByte(i);
+      if (numBytesForFirstByte(b) != 1) {
+        // fallback
+        return toLowerCaseSlow();
+      }
+      int lower = Character.toLowerCase((int) b);
+      if (lower > 127) {
+        // fallback
+        return toLowerCaseSlow();
+      }
+      bytes[i] = (byte) lower;
+    }
+    return fromBytes(bytes);
+  }
+
+  private UTF8String toLowerCaseSlow() {
     return fromString(toString().toLowerCase());
   }
 
+  /**
+   * Returns the title case of this string, that could be used as title.
+   */
+  public UTF8String toTitleCase() {
+    if (numBytes == 0) {
+      return EMPTY_UTF8;
+    }
+
+    byte[] bytes = new byte[numBytes];
+    for (int i = 0; i < numBytes; i++) {
+      byte b = getByte(i);
+      if (i == 0 || getByte(i - 1) == ' ') {
+        if (numBytesForFirstByte(b) != 1) {
+          // fallback
+          return toTitleCaseSlow();
+        }
+        int upper = Character.toTitleCase(b);
+        if (upper > 127) {
+          // fallback
+          return toTitleCaseSlow();
+        }
+        bytes[i] = (byte) upper;
+      } else {
+        bytes[i] = b;
+      }
+    }
+    return fromBytes(bytes);
+  }
+
+  private UTF8String toTitleCaseSlow() {
+    StringBuffer sb = new StringBuffer();
+    String s = toString();
+    sb.append(s);
+    sb.setCharAt(0, Character.toTitleCase(sb.charAt(0)));
+    for (int i = 1; i < s.length(); i++) {
+      if (sb.charAt(i - 1) == ' ') {
+        sb.setCharAt(i, Character.toTitleCase(sb.charAt(i)));
+      }
+    }
+    return fromString(sb.toString());
+  }
+
   /**
    * Copy the bytes from the current UTF8String, and make a new UTF8String.
    * @param start the start position of the current UTF8String in bytes.
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index ed50cdcb29..9b3190f8f0 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -114,6 +114,14 @@ public class UTF8StringSuite {
     testUpperandLower("大千世界 数据砖头", "大千世界 数据砖头");
   }
 
+  @Test
+  public void titleCase() {
+    assertEquals(fromString(""), fromString("").toTitleCase());
+    assertEquals(fromString("Ab Bc Cd"), fromString("ab bc cd").toTitleCase());
+    assertEquals(fromString("Ѐ Ё Ђ Ѻ Ώ Ề"), fromString("ѐ ё ђ ѻ ώ ề").toTitleCase());
+    assertEquals(fromString("大千世界 数据砖头"), fromString("大千世界 数据砖头").toTitleCase());
+  }
+
   @Test
   public void concatTest() {
     assertEquals(EMPTY_UTF8, concat());
-- 
GitLab