From 37e4d92142a6309e2df7d36883e0c7892c3d792d Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Mon, 6 Jul 2015 13:31:31 -0700
Subject: [PATCH] [SPARK-8784] [SQL] Add Python API for hex and unhex

Add Python API for hex/unhex,  also cleanup Hex/Unhex

Author: Davies Liu <davies@databricks.com>

Closes #7223 from davies/hex and squashes the following commits:

6f1249d [Davies Liu] no explicit rule to cast string into binary
711a6ed [Davies Liu] fix test
f9fe5a3 [Davies Liu] Merge branch 'master' of github.com:apache/spark into hex
f032fbb [Davies Liu] Merge branch 'hex' of github.com:davies/spark into hex
49e325f [Davies Liu] Merge branch 'master' of github.com:apache/spark into hex
b31fc9a [Davies Liu] Update math.scala
25156b7 [Davies Liu] address comments and fix test
c3af78c [Davies Liu] address commments
1a24082 [Davies Liu] Add Python API for hex and unhex
---
 python/pyspark/sql/functions.py               | 28 +++++++
 .../catalyst/analysis/FunctionRegistry.scala  |  2 +-
 .../spark/sql/catalyst/expressions/math.scala | 83 ++++++++++---------
 .../expressions/MathFunctionsSuite.scala      | 25 ++++--
 .../org/apache/spark/sql/functions.scala      |  2 +-
 5 files changed, 93 insertions(+), 47 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 49dd0332af..dca39fa833 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -395,6 +395,34 @@ def randn(seed=None):
     return Column(jc)
 
 
+@ignore_unicode_prefix
+@since(1.5)
+def hex(col):
+    """Computes hex value of the given column, which could be StringType,
+    BinaryType, IntegerType or LongType.
+
+    >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
+    [Row(hex(a)=u'414243', hex(b)=u'3')]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.hex(_to_java_column(col))
+    return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def unhex(col):
+    """Inverse of hex. Interprets each pair of characters as a hexadecimal number
+    and converts to the byte representation of number.
+
+    >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
+    [Row(unhex(a)=bytearray(b'ABC'))]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.unhex(_to_java_column(col))
+    return Column(jc)
+
+
 @ignore_unicode_prefix
 @since(1.5)
 def sha1(col):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 92a50e7092..fef2763530 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -168,7 +168,7 @@ object FunctionRegistry {
     expression[Substring]("substring"),
     expression[UnBase64]("unbase64"),
     expression[Upper]("ucase"),
-    expression[UnHex]("unhex"),
+    expression[Unhex]("unhex"),
     expression[Upper]("upper"),
 
     // datetime functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 45b7e4d340..9250045398 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -298,6 +298,21 @@ case class Bin(child: Expression)
   }
 }
 
+object Hex {
+  val hexDigits = Array[Char](
+    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
+  ).map(_.toByte)
+
+  // lookup table to translate '0' -> 0 ... 'F'/'f' -> 15
+  val unhexDigits = {
+    val array = Array.fill[Byte](128)(-1)
+    (0 to 9).foreach(i => array('0' + i) = i.toByte)
+    (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte)
+    (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte)
+    array
+  }
+}
+
 /**
  * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format.
  * Otherwise if the number is a STRING, it converts each character into its hex representation
@@ -307,7 +322,7 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes
   // TODO: Create code-gen version.
 
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(LongType, StringType, BinaryType))
+    Seq(TypeCollection(LongType, BinaryType, StringType))
 
   override def dataType: DataType = StringType
 
@@ -319,30 +334,18 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes
       child.dataType match {
         case LongType => hex(num.asInstanceOf[Long])
         case BinaryType => hex(num.asInstanceOf[Array[Byte]])
-        case StringType => hex(num.asInstanceOf[UTF8String])
+        case StringType => hex(num.asInstanceOf[UTF8String].getBytes)
       }
     }
   }
 
-  /**
-   * Converts every character in s to two hex digits.
-   */
-  private def hex(str: UTF8String): UTF8String = {
-    hex(str.getBytes)
-  }
-
-  private def hex(bytes: Array[Byte]): UTF8String = {
-    doHex(bytes, bytes.length)
-  }
-
-  private def doHex(bytes: Array[Byte], length: Int): UTF8String = {
+  private[this] def hex(bytes: Array[Byte]): UTF8String = {
+    val length = bytes.length
     val value = new Array[Byte](length * 2)
     var i = 0
     while (i < length) {
-      value(i * 2) = Character.toUpperCase(Character.forDigit(
-        (bytes(i) & 0xF0) >>> 4, 16)).toByte
-      value(i * 2 + 1) = Character.toUpperCase(Character.forDigit(
-        bytes(i) & 0x0F, 16)).toByte
+      value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4)
+      value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F)
       i += 1
     }
     UTF8String.fromBytes(value)
@@ -355,24 +358,23 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes
     var len = 0
     do {
       len += 1
-      value(value.length - len) =
-        Character.toUpperCase(Character.forDigit((numBuf & 0xF).toInt, 16)).toByte
+      value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt)
       numBuf >>>= 4
     } while (numBuf != 0)
     UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length))
   }
 }
 
-
 /**
  * Performs the inverse operation of HEX.
  * Resulting characters are returned as a byte array.
  */
-case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTypes {
   // TODO: Create code-gen version.
 
   override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
 
+  override def nullable: Boolean = true
   override def dataType: DataType = BinaryType
 
   override def eval(input: InternalRow): Any = {
@@ -384,26 +386,31 @@ case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTyp
     }
   }
 
-  private val unhexDigits = {
-    val array = Array.fill[Byte](128)(-1)
-    (0 to 9).foreach(i => array('0' + i) = i.toByte)
-    (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte)
-    (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte)
-    array
-  }
-
-  private def unhex(inputBytes: Array[Byte]): Array[Byte] = {
-    var bytes = inputBytes
+  private[this] def unhex(bytes: Array[Byte]): Array[Byte] = {
+    val out = new Array[Byte]((bytes.length + 1) >> 1)
+    var i = 0
     if ((bytes.length & 0x01) != 0) {
-      bytes = '0'.toByte +: bytes
+      // padding with '0'
+      if (bytes(0) < 0) {
+        return null
+      }
+      val v = Hex.unhexDigits(bytes(0))
+      if (v == -1) {
+        return null
+      }
+      out(0) = v
+      i += 1
     }
-    val out = new Array[Byte](bytes.length >> 1)
     // two characters form the hex value.
-    var i = 0
     while (i < bytes.length) {
-      val first = unhexDigits(bytes(i))
-      val second = unhexDigits(bytes(i + 1))
-      if (first == -1 || second == -1) { return null}
+      if (bytes(i) < 0 || bytes(i + 1) < 0) {
+        return null
+      }
+      val first = Hex.unhexDigits(bytes(i))
+      val second = Hex.unhexDigits(bytes(i + 1))
+      if (first == -1 || second == -1) {
+        return null
+      }
       out(i / 2) = (((first << 4) | second) & 0xFF).toByte
       i += 2
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 03d8400cf3..7ca9e30b2b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -21,8 +21,7 @@ import com.google.common.math.LongMath
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{DataType, LongType}
-import org.apache.spark.sql.types.{IntegerType, DoubleType}
+import org.apache.spark.sql.types._
 
 class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -271,20 +270,32 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("hex") {
+    checkEvaluation(Hex(Literal.create(null, LongType)), null)
+    checkEvaluation(Hex(Literal(28L)), "1C")
+    checkEvaluation(Hex(Literal(-28L)), "FFFFFFFFFFFFFFE4")
     checkEvaluation(Hex(Literal(100800200404L)), "177828FED4")
     checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C")
-    checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578")
+    checkEvaluation(Hex(Literal.create(null, BinaryType)), null)
     checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578")
     // scalastyle:off
     // Turn off scala style for non-ascii chars
-    checkEvaluation(Hex(Literal("三重的")), "E4B889E9878DE79A84")
+    checkEvaluation(Hex(Literal("三重的".getBytes("UTF8"))), "E4B889E9878DE79A84")
     // scalastyle:on
   }
 
   test("unhex") {
-    checkEvaluation(UnHex(Literal("737472696E67")), "string".getBytes)
-    checkEvaluation(UnHex(Literal("")), new Array[Byte](0))
-    checkEvaluation(UnHex(Literal("0")), Array[Byte](0))
+    checkEvaluation(Unhex(Literal.create(null, StringType)), null)
+    checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes)
+    checkEvaluation(Unhex(Literal("")), new Array[Byte](0))
+    checkEvaluation(Unhex(Literal("F")), Array[Byte](15))
+    checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1))
+    checkEvaluation(Unhex(Literal("GG")), null)
+    // scalastyle:off
+    // Turn off scala style for non-ascii chars
+    checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8"))
+    checkEvaluation(Unhex(Literal("三重的")), null)
+
+    // scalastyle:on
   }
 
   test("hypot") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f80291776f..4da9ffc495 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1095,7 +1095,7 @@ object functions {
    * @group math_funcs
    * @since 1.5.0
    */
-  def unhex(column: Column): Column = UnHex(column.expr)
+  def unhex(column: Column): Column = Unhex(column.expr)
 
   /**
    * Inverse of hex. Interprets each pair of characters as a hexadecimal number
-- 
GitLab