From a5c2961caaafd751f11bdd406bb6885443d7572e Mon Sep 17 00:00:00 2001
From: Tarek Auel <tarek.auel@gmail.com>
Date: Mon, 29 Jun 2015 11:57:19 -0700
Subject: [PATCH] [SPARK-8235] [SQL] misc function sha / sha1

Jira: https://issues.apache.org/jira/browse/SPARK-8235

I added the support for sha1. If I understood rxin correctly, sha and sha1 should execute the same algorithm, shouldn't they?

Please take a close look on the Python part. This is adopted from #6934

Author: Tarek Auel <tarek.auel@gmail.com>
Author: Tarek Auel <tarek.auel@googlemail.com>

Closes #6963 from tarekauel/SPARK-8235 and squashes the following commits:

f064563 [Tarek Auel] change to shaHex
7ce3cdc [Tarek Auel] rely on automatic cast
a1251d6 [Tarek Auel] Merge remote-tracking branch 'upstream/master' into SPARK-8235
68eb043 [Tarek Auel] added docstring
be5aff1 [Tarek Auel] improved error message
7336c96 [Tarek Auel] added type check
cf23a80 [Tarek Auel] simplified example
ebf75ef [Tarek Auel] [SPARK-8301] updated the python documentation. Removed sha in python and scala
6d6ff0d [Tarek Auel] [SPARK-8233] added docstring
ea191a9 [Tarek Auel] [SPARK-8233] fixed signatureof python function. Added expected type to misc
e3fd7c3 [Tarek Auel] SPARK[8235] added sha to the list of __all__
e5dad4e [Tarek Auel] SPARK[8235] sha / sha1
---
 python/pyspark/sql/functions.py               | 14 +++++++++
 .../catalyst/analysis/FunctionRegistry.scala  |  2 ++
 .../spark/sql/catalyst/expressions/misc.scala | 30 ++++++++++++++++++-
 .../expressions/MiscFunctionsSuite.scala      |  8 +++++
 .../org/apache/spark/sql/functions.scala      | 16 ++++++++++
 .../spark/sql/DataFrameFunctionsSuite.scala   | 12 ++++++++
 6 files changed, 81 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 7d3d036161..45ecd826bd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -42,6 +42,7 @@ __all__ = [
     'monotonicallyIncreasingId',
     'rand',
     'randn',
+    'sha1',
     'sha2',
     'sparkPartitionId',
     'struct',
@@ -382,6 +383,19 @@ def sha2(col, numBits):
     return Column(jc)
 
 
+@ignore_unicode_prefix
+@since(1.5)
+def sha1(col):
+    """Returns the hex string result of SHA-1.
+
+    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
+    [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.sha1(_to_java_column(col))
+    return Column(jc)
+
+
 @since(1.4)
 def sparkPartitionId():
     """A column for partition ID of the Spark task.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 457948a800..b24064d061 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -136,6 +136,8 @@ object FunctionRegistry {
     // misc functions
     expression[Md5]("md5"),
     expression[Sha2]("sha2"),
+    expression[Sha1]("sha1"),
+    expression[Sha1]("sha"),
 
     // aggregate functions
     expression[Average]("avg"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index e80706fc65..9a39165a1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -21,8 +21,9 @@ import java.security.MessageDigest
 import java.security.NoSuchAlgorithmException
 
 import org.apache.commons.codec.digest.DigestUtils
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, DataType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -140,3 +141,30 @@ case class Sha2(left: Expression, right: Expression)
     """
   }
 }
+
+/**
+ * A function that calculates a sha1 hash value and returns it as a hex string
+ * For input of type [[BinaryType]] or [[StringType]]
+ */
+case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+
+  override def dataType: DataType = StringType
+
+  override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+
+  override def eval(input: InternalRow): Any = {
+    val value = child.eval(input)
+    if (value == null) {
+      null
+    } else {
+      UTF8String.fromString(DigestUtils.shaHex(value.asInstanceOf[Array[Byte]]))
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    defineCodeGen(ctx, ev, c =>
+      "org.apache.spark.unsafe.types.UTF8String.fromString" +
+        s"(org.apache.commons.codec.digest.DigestUtils.shaHex($c))"
+    )
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
index 38482c54c6..36e636b5da 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
@@ -31,6 +31,14 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
   }
 
+  test("sha1") {
+    checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")
+    checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
+      "5d211bad8f4ee70e16c7d343a838fc344a1ed961")
+    checkEvaluation(Sha1(Literal.create(null, BinaryType)), null)
+    checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709")
+  }
+
   test("sha2") {
     checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
     checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 355ce0e342..ef92801548 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1414,6 +1414,22 @@ object functions {
    */
   def md5(columnName: String): Column = md5(Column(columnName))
 
+  /**
+   * Calculates the SHA-1 digest and returns the value as a 40 character hex string.
+   *
+   * @group misc_funcs
+   * @since 1.5.0
+   */
+  def sha1(e: Column): Column = Sha1(e.expr)
+
+  /**
+   * Calculates the SHA-1 digest and returns the value as a 40 character hex string.
+   *
+   * @group misc_funcs
+   * @since 1.5.0
+   */
+  def sha1(columnName: String): Column = sha1(Column(columnName))
+
   /**
    * Calculates the SHA-2 family of hash functions and returns the value as a hex string.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 8baed57a7f..abfd47c811 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -144,6 +144,18 @@ class DataFrameFunctionsSuite extends QueryTest {
       Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c"))
   }
 
+  test("misc sha1 function") {
+    val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b")
+    checkAnswer(
+      df.select(sha1($"a"), sha1("b")),
+      Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8"))
+
+    val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b")
+    checkAnswer(
+      dfEmpty.selectExpr("sha1(a)", "sha1(b)"),
+      Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709"))
+  }
+
   test("misc sha2 function") {
     val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
     checkAnswer(
-- 
GitLab