diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5fc1cc2cae10a19bcac6addd3dadf6c73a0944e2..7d038b8d9e89b5ed88abc1976028b43e5c70f004 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1125,6 +1125,43 @@ def hash(*cols): return Column(jc) +@ignore_unicode_prefix +@since(2.0) +def aes_encrypt(input, key): + """ + Encrypts input of given column using AES. Key lengths of 128, 192 or 256 bits can be used. 192 + and 256 bits keys can be used if Java Cryptography Extension (JCE) Unlimited Strength Jurisdic- + tion Policy Files are installed. If input is invalid, key length is not one of the permitted + values or using 192/256 bits key before installing JCE, an exception will be thrown. + + >>> df = sqlContext.createDataFrame([('ABC','1234567890123456')], ['input','key']) + >>> df.select(base64(aes_encrypt(df.input, df.key)).alias('aes')).collect() + [Row(aes=u'y6Ss+zCYObpCbgfWfyNWTw==')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.aes_encrypt(_to_java_column(input), _to_java_column(key)) + return Column(jc) + + +@ignore_unicode_prefix +@since(2.0) +def aes_decrypt(input, key): + """ + Decrypts input of given column using AES. Key lengths of 128, 192 or 256 bits can be used. 192 + and 256 bits keys can be used if Java Cryptography Extension (JCE) Unlimited Strength Jurisdic- + tion Policy Files are installed. If input is invalid, key length is not one of the permitted + values or using 192/256 bits key before installing JCE, an exception will be thrown. + + >>> df = sqlContext.createDataFrame([(u'y6Ss+zCYObpCbgfWfyNWTw==','1234567890123456')], \ + ['input','key']) + >>> df.select(aes_decrypt(unbase64(df.input), df.key).alias('aes')).collect() + [Row(aes=u'ABC')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.aes_decrypt(_to_java_column(input), _to_java_column(key)) + return Column(jc) + + # ---------------------- String/Binary functions ------------------------------ _string_functions = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1be97c7b81197336c3cdad0f94dfd5c176255036..ae09c3d71f8833b77fa5546bb116c9e6ea89f2b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -278,6 +278,8 @@ object FunctionRegistry { expression[ArrayContains]("array_contains"), // misc functions + expression[AesEncrypt]("aes_encrypt"), + expression[AesDecrypt]("aes_decrypt"), expression[Crc32]("crc32"), expression[Md5]("md5"), expression[Murmur3Hash]("hash"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index dcbb594afd86e984e6dd9b4ab260ad0364c1c932..3b66f5797be64af533a621190f548aa2b2111ce9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 +import javax.crypto.Cipher +import javax.crypto.spec.SecretKeySpec import org.apache.commons.codec.digest.DigestUtils @@ -441,3 +443,90 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { """.stripMargin) } } + +/** + * A function that encrypts input using AES. Key lengths of 128, 192 or 256 bits can be used. 192 + * and 256 bits keys can be used if Java Cryptography Extension (JCE) Unlimited Strength Jurisdic- + * tion Policy Files are installed. If either argument is NULL, the result will also be null. If + * input is invalid, key length is not one of the permitted values or using 192/256 bits key before + * installing JCE, an exception will be thrown. + */ +@ExpressionDescription( + usage = "_FUNC_(input, key) - Encrypts input using AES.", + extended = "> SELECT Base64(_FUNC_('ABC', '1234567890123456'));\n 'y6Ss+zCYObpCbgfWfyNWTw=='") +case class AesEncrypt(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = BinaryType + override def inputTypes: Seq[DataType] = Seq(BinaryType, BinaryType) + + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val cipher = Cipher.getInstance("AES") + val secretKey: SecretKeySpec = new SecretKeySpec(input2.asInstanceOf[Array[Byte]], 0, + input2.asInstanceOf[Array[Byte]].length, "AES") + cipher.init(Cipher.ENCRYPT_MODE, secretKey) + cipher.doFinal(input1.asInstanceOf[Array[Byte]], 0, input1.asInstanceOf[Array[Byte]].length) + } + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + nullSafeCodeGen(ctx, ev, (str, key) => { + val Cipher = "javax.crypto.Cipher" + val SecretKeySpec = "javax.crypto.spec.SecretKeySpec" + s""" + try { + $Cipher cipher = $Cipher.getInstance("AES"); + $SecretKeySpec secret = new $SecretKeySpec($key, 0, $key.length, "AES"); + cipher.init($Cipher.ENCRYPT_MODE, secret); + ${ev.value} = cipher.doFinal($str, 0, $str.length); + } catch (java.security.GeneralSecurityException e) { + org.apache.spark.unsafe.Platform.throwException(e); + } + """ + }) + } +} + +/** + * A function that decrypts input using AES. Key lengths of 128, 192 or 256 bits can be used. 192 + * and 256 bits keys can be used if Java Cryptography Extension (JCE) Unlimited Strength Jurisdic- + * tion Policy Files are installed. If either argument is NULL, the result will also be null. If + * input is invalid, key length is not one of the permitted values or using 192/256 bits key before + * installing JCE, an exception will be thrown. + */ +@ExpressionDescription( + usage = "_FUNC_(input, key) - Decrypts input using AES.", + extended = "> SELECT _FUNC_(UnBase64('y6Ss+zCYObpCbgfWfyNWTw=='),'1234567890123456');\n 'ABC'") +case class AesDecrypt(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(BinaryType, BinaryType) + + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val cipher = Cipher.getInstance("AES") + val secretKey = new SecretKeySpec(input2.asInstanceOf[Array[Byte]], 0, + input2.asInstanceOf[Array[Byte]].length, "AES") + + cipher.init(Cipher.DECRYPT_MODE, secretKey) + UTF8String.fromBytes( + cipher.doFinal(input1.asInstanceOf[Array[Byte]], 0, + input1.asInstanceOf[Array[Byte]].length)) + } + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + nullSafeCodeGen(ctx, ev, (str, key) => { + val Cipher = "javax.crypto.Cipher" + val SecretKeySpec = "javax.crypto.spec.SecretKeySpec" + s""" + try { + $Cipher cipher = $Cipher.getInstance("AES"); + $SecretKeySpec secret = new $SecretKeySpec($key, 0, $key.length, "AES"); + cipher.init($Cipher.DECRYPT_MODE, secret); + ${ev.value} = UTF8String.fromBytes(cipher.doFinal($str, 0, $str.length)); + } catch (java.security.GeneralSecurityException e) { + org.apache.spark.unsafe.Platform.throwException(e); + } + """ + }) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 75131a6170222d4063acd8ae5afc3ec14ca8272b..67f2dc457d3333f736cbfe1154538dbd1c1a0de2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -132,4 +132,88 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + + test("aesEncrypt") { + val expr1 = AesEncrypt(Literal("ABC".getBytes), Literal("1234567890123456".getBytes)) + val expr2 = AesEncrypt(Literal("".getBytes), Literal("1234567890123456".getBytes)) + + checkEvaluation(Base64(expr1), "y6Ss+zCYObpCbgfWfyNWTw==") + checkEvaluation(Base64(expr2), "BQGHoM3lqYcsurCRq3PlUw==") + + // input is null + checkEvaluation(AesEncrypt(Literal.create(null, BinaryType), + Literal("1234567890123456".getBytes)), null) + // key is null + checkEvaluation(AesEncrypt(Literal("ABC".getBytes), + Literal.create(null, BinaryType)), null) + // both are null + checkEvaluation(AesEncrypt(Literal.create(null, BinaryType), + Literal.create(null, BinaryType)), null) + + val expr3 = AesEncrypt(Literal("ABC".getBytes), Literal("1234567890".getBytes)) + // key length (80 bits) is not one of the permitted values (128, 192 or 256 bits) + intercept[java.security.InvalidKeyException] { + evaluate(expr3) + } + intercept[java.security.InvalidKeyException] { + UnsafeProjection.create(expr3 :: Nil).apply(null) + } + } + + test("aesDecrypt") { + val expr1 = AesDecrypt(UnBase64(Literal("y6Ss+zCYObpCbgfWfyNWTw==")), + Literal("1234567890123456".getBytes)) + val expr2 = AesDecrypt(UnBase64(Literal("BQGHoM3lqYcsurCRq3PlUw==")), + Literal("1234567890123456".getBytes)) + + checkEvaluation(expr1, "ABC") + checkEvaluation(expr2, "") + + // input is null + checkEvaluation(AesDecrypt(UnBase64(Literal.create(null, StringType)), + Literal("1234567890123456".getBytes)), null) + // key is null + checkEvaluation(AesDecrypt(UnBase64(Literal("y6Ss+zCYObpCbgfWfyNWTw==")), + Literal.create(null, BinaryType)), null) + // both are null + checkEvaluation(AesDecrypt(UnBase64(Literal.create(null, StringType)), + Literal.create(null, BinaryType)), null) + + val expr3 = AesDecrypt(UnBase64(Literal("y6Ss+zCYObpCbgfWfyNWTw==")), + Literal("1234567890".getBytes)) + val expr4 = AesDecrypt(UnBase64(Literal("y6Ss+zCsdYObpCbgfWfyNW3Twewr")), + Literal("1234567890123456".getBytes)) + val expr5 = AesDecrypt(UnBase64(Literal("t6Ss+zCYObpCbgfWfyNWTw==")), + Literal("1234567890123456".getBytes)) + + // key length (80 bits) is not one of the permitted values (128, 192 or 256 bits) + intercept[java.security.InvalidKeyException] { + evaluate(expr3) + } + intercept[java.security.InvalidKeyException] { + UnsafeProjection.create(expr3 :: Nil).apply(null) + } + // input can not be decrypted + intercept[javax.crypto.IllegalBlockSizeException] { + evaluate(expr4) + } + intercept[javax.crypto.IllegalBlockSizeException] { + UnsafeProjection.create(expr4 :: Nil).apply(null) + } + // input can not be decrypted + intercept[javax.crypto.BadPaddingException] { + evaluate(expr5) + } + intercept[javax.crypto.BadPaddingException] { + UnsafeProjection.create(expr5 :: Nil).apply(null) + } + } + + ignore("aesEncryptWith256bitsKey") { + // Before testing this, installing Java Cryptography Extension (JCE) Unlimited Strength Juris- + // diction Policy Files first. Otherwise `java.security.InvalidKeyException` will be thrown. + // Because Oracle JDK does not support 192 and 256 bits key out of box. + checkEvaluation(Base64(AesEncrypt(Literal("ABC".getBytes), + Literal("12345678901234561234567890123456".getBytes))), "nYfCuJeRd5eD60yXDw7WEA==") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 97c6992e18753055f0544b00908e24b8b0ce0b4f..8da50bedfc8cb794b54d4a08aacd65cd16ae725b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1982,6 +1982,42 @@ object functions extends LegacyFunctions { new Murmur3Hash(cols.map(_.expr)) } + /** + * Encrypts input using AES and Returns the result as a binary column. + * Key lengths of 128, 192 or 256 bits can be used. 192 and 256 bits keys can be used if Java + * Cryptography Extension (JCE) Unlimited Strength Jurisdiction Policy Files are installed. If + * either argument is NULL, the result will also be null. If input is invalid, key length is not + * one of the permitted values or using 192/256 bits key before installing JCE, an exception will + * be thrown. + * + * @param input binary column to encrypt input + * @param key binary column of 128, 192 or 256 bits key + * + * @group misc_funcs + * @since 2.0.0 + */ + def aes_encrypt(input: Column, key: Column): Column = withExpr { + AesEncrypt(input.expr, key.expr) + } + + /** + * Decrypts input using AES and Returns the result as a string column. + * Key lengths of 128, 192 or 256 bits can be used. 192 and 256 bits keys can be used if Java + * Cryptography Extension (JCE) Unlimited Strength Jurisdiction Policy Files are installed. If + * either argument is NULL, the result will also be null. If input is invalid, key length is not + * one of the permitted values or using 192/256 bits key before installing JCE, an exception will + * be thrown. + * + * @param input binary column to decrypt input + * @param key binary column of 128, 192 or 256 bits key + * + * @group misc_funcs + * @since 2.0.0 + */ + def aes_decrypt(input: Column, key: Column): Column = withExpr { + AesDecrypt(input.expr, key.expr) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index aff9efe4b2b16aa27325a1a87288975cbb7801b9..0381d5728077bc0c85606618bacc9b72bafa8e56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -206,6 +206,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(2743272264L, 2180413220L)) } + test("misc aes encrypt function") { + val df = Seq(("ABC", "1234567890123456")).toDF("input", "key") + checkAnswer( + df.select(base64(aes_encrypt($"input", $"key"))), + Row("y6Ss+zCYObpCbgfWfyNWTw==") + ) + checkAnswer( + sql("SELECT base64(aes_encrypt('', '1234567890123456'))"), + Row("BQGHoM3lqYcsurCRq3PlUw==") + ) + } + + test("misc aes decrypt function") { + val df = Seq(("y6Ss+zCYObpCbgfWfyNWTw==", "1234567890123456")).toDF("input", "key") + checkAnswer( + df.select((aes_decrypt(unbase64($"input"), $"key"))), + Row("ABC") + ) + checkAnswer( + sql("SELECT aes_decrypt(unbase64('BQGHoM3lqYcsurCRq3PlUw=='), '1234567890123456')"), + Row("") + ) + } + test("string function find_in_set") { val df = Seq(("abc,b,ab,c,def", "abc,b,ab,c,def")).toDF("a", "b")