You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/02/20 07:28:53 UTC

spark git commit: [SPARK-12567] [SQL] Add aes_{encrypt,decrypt} UDFs

Repository: spark
Updated Branches:
  refs/heads/master ec7a1d6e4 -> 4f9a66481


[SPARK-12567] [SQL] Add aes_{encrypt,decrypt} UDFs

Author: Kai Jiang <ji...@gmail.com>

Closes #10527 from vectorijk/spark-12567.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4f9a6648
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4f9a6648
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4f9a6648

Branch: refs/heads/master
Commit: 4f9a66481849dc867cf6592d53e0e9782361d20a
Parents: ec7a1d6
Author: Kai Jiang <ji...@gmail.com>
Authored: Fri Feb 19 22:28:47 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Fri Feb 19 22:28:47 2016 -0800

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 37 ++++++++
 .../catalyst/analysis/FunctionRegistry.scala    |  2 +
 .../spark/sql/catalyst/expressions/misc.scala   | 89 ++++++++++++++++++++
 .../expressions/MiscFunctionsSuite.scala        | 84 ++++++++++++++++++
 .../scala/org/apache/spark/sql/functions.scala  | 36 ++++++++
 .../spark/sql/DataFrameFunctionsSuite.scala     | 24 ++++++
 6 files changed, 272 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4f9a6648/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 5fc1cc2..7d038b8 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1125,6 +1125,43 @@ def hash(*cols):
     return Column(jc)
 
 
+@ignore_unicode_prefix
+@since(2.0)
+def aes_encrypt(input, key):
+    """
+    Encrypts the input column using AES and returns the result as binary. The key must be 128,
+    192 or 256 bits long; 192- and 256-bit keys additionally require the Java Cryptography
+    Extension (JCE) Unlimited Strength Jurisdiction Policy Files. An exception is thrown if the
+    input is invalid, the key length is not permitted, or such a key is used without those files.
+
+    >>> df = sqlContext.createDataFrame([('ABC','1234567890123456')], ['input','key'])
+    >>> df.select(base64(aes_encrypt(df.input, df.key)).alias('aes')).collect()
+    [Row(aes=u'y6Ss+zCYObpCbgfWfyNWTw==')]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.aes_encrypt(_to_java_column(input), _to_java_column(key))
+    return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(2.0)
+def aes_decrypt(input, key):
+    """
+    Decrypts the input column using AES and returns the result as a string. The key must be 128,
+    192 or 256 bits long; 192- and 256-bit keys additionally require the Java Cryptography
+    Extension (JCE) Unlimited Strength Jurisdiction Policy Files. An exception is thrown if the
+    input is invalid, the key length is not permitted, or such a key is used without those files.
+
+    >>> df = sqlContext.createDataFrame([(u'y6Ss+zCYObpCbgfWfyNWTw==','1234567890123456')], \
+    ['input','key'])
+    >>> df.select(aes_decrypt(unbase64(df.input), df.key).alias('aes')).collect()
+    [Row(aes=u'ABC')]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.aes_decrypt(_to_java_column(input), _to_java_column(key))
+    return Column(jc)
+
+
 # ---------------------- String/Binary functions ------------------------------
 
 _string_functions = {

http://git-wip-us.apache.org/repos/asf/spark/blob/4f9a6648/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 1be97c7..ae09c3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -278,6 +278,8 @@ object FunctionRegistry {
     expression[ArrayContains]("array_contains"),
 
     // misc functions
+    expression[AesEncrypt]("aes_encrypt"),
+    expression[AesDecrypt]("aes_decrypt"),
     expression[Crc32]("crc32"),
     expression[Md5]("md5"),
     expression[Murmur3Hash]("hash"),

http://git-wip-us.apache.org/repos/asf/spark/blob/4f9a6648/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index dcbb594..3b66f57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.security.{MessageDigest, NoSuchAlgorithmException}
 import java.util.zip.CRC32
+import javax.crypto.Cipher
+import javax.crypto.spec.SecretKeySpec
 
 import org.apache.commons.codec.digest.DigestUtils
 
@@ -441,3 +443,90 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
        """.stripMargin)
   }
 }
+
+/**
+ * Encrypts binary input with AES and returns the ciphertext as binary. The key must be 128, 192
+ * or 256 bits long; 192- and 256-bit keys additionally require the Java Cryptography Extension
+ * (JCE) Unlimited Strength Jurisdiction Policy Files. If either argument is null, the result is
+ * null. An exception is thrown if the input is invalid, the key length is not permitted, or a
+ * 192/256-bit key is used without the JCE policy files installed.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(input, key) - Encrypts input using AES.",
+  extended = "> SELECT Base64(_FUNC_('ABC', '1234567890123456'));\n 'y6Ss+zCYObpCbgfWfyNWTw=='")
+case class AesEncrypt(left: Expression, right: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = BinaryType
+  override def inputTypes: Seq[DataType] = Seq(BinaryType, BinaryType)
+
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val cipher = Cipher.getInstance("AES")  // no mode/padding named: provider defaults apply
+    val secretKey: SecretKeySpec = new SecretKeySpec(input2.asInstanceOf[Array[Byte]], 0,
+      input2.asInstanceOf[Array[Byte]].length, "AES")
+    cipher.init(Cipher.ENCRYPT_MODE, secretKey)
+    cipher.doFinal(input1.asInstanceOf[Array[Byte]], 0, input1.asInstanceOf[Array[Byte]].length)
+  }
+
+  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {  // mirrors nullSafeEval
+    nullSafeCodeGen(ctx, ev, (str, key) => {
+      val Cipher = "javax.crypto.Cipher"
+      val SecretKeySpec = "javax.crypto.spec.SecretKeySpec"
+      s"""
+          try {
+            $Cipher cipher = $Cipher.getInstance("AES");
+            $SecretKeySpec secret = new $SecretKeySpec($key, 0, $key.length, "AES");
+            cipher.init($Cipher.ENCRYPT_MODE, secret);
+            ${ev.value} = cipher.doFinal($str, 0, $str.length);
+          } catch (java.security.GeneralSecurityException e) {
+            org.apache.spark.unsafe.Platform.throwException(e);
+          }
+      """
+    })
+  }
+}
+
+/**
+ * Decrypts binary input with AES and returns the plaintext as a string. The key must be 128, 192
+ * or 256 bits long; 192- and 256-bit keys additionally require the Java Cryptography Extension
+ * (JCE) Unlimited Strength Jurisdiction Policy Files. If either argument is null, the result is
+ * null. An exception is thrown if the input is invalid, the key length is not permitted, or a
+ * 192/256-bit key is used without the JCE policy files installed.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(input, key) - Decrypts input using AES.",
+  extended = "> SELECT _FUNC_(UnBase64('y6Ss+zCYObpCbgfWfyNWTw=='),'1234567890123456');\n 'ABC'")
+case class AesDecrypt(left: Expression, right: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = StringType
+  override def inputTypes: Seq[DataType] = Seq(BinaryType, BinaryType)
+
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val cipher = Cipher.getInstance("AES")  // no mode/padding named: provider defaults apply
+    val secretKey = new SecretKeySpec(input2.asInstanceOf[Array[Byte]], 0,
+      input2.asInstanceOf[Array[Byte]].length, "AES")
+
+    cipher.init(Cipher.DECRYPT_MODE, secretKey)
+    UTF8String.fromBytes(
+      cipher.doFinal(input1.asInstanceOf[Array[Byte]], 0,
+        input1.asInstanceOf[Array[Byte]].length))
+  }
+
+  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {  // mirrors nullSafeEval
+    nullSafeCodeGen(ctx, ev, (str, key) => {
+      val Cipher = "javax.crypto.Cipher"
+      val SecretKeySpec = "javax.crypto.spec.SecretKeySpec"
+      s"""
+          try {
+            $Cipher cipher = $Cipher.getInstance("AES");
+            $SecretKeySpec secret = new $SecretKeySpec($key, 0, $key.length, "AES");
+            cipher.init($Cipher.DECRYPT_MODE, secret);
+            ${ev.value} = UTF8String.fromBytes(cipher.doFinal($str, 0, $str.length));
+          } catch (java.security.GeneralSecurityException e) {
+            org.apache.spark.unsafe.Platform.throwException(e);
+          }
+      """
+    })
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4f9a6648/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
index 75131a6..67f2dc4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
@@ -132,4 +132,88 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       }
     }
   }
+
+  test("aesEncrypt") {
+    val expr1 = AesEncrypt(Literal("ABC".getBytes), Literal("1234567890123456".getBytes))
+    val expr2 = AesEncrypt(Literal("".getBytes), Literal("1234567890123456".getBytes))
+
+    checkEvaluation(Base64(expr1), "y6Ss+zCYObpCbgfWfyNWTw==")
+    checkEvaluation(Base64(expr2), "BQGHoM3lqYcsurCRq3PlUw==")
+
+    // a null input yields null
+    checkEvaluation(AesEncrypt(Literal.create(null, BinaryType),
+      Literal("1234567890123456".getBytes)), null)
+    // a null key yields null
+    checkEvaluation(AesEncrypt(Literal("ABC".getBytes),
+      Literal.create(null, BinaryType)), null)
+    // a null input together with a null key also yields null
+    checkEvaluation(AesEncrypt(Literal.create(null, BinaryType),
+      Literal.create(null, BinaryType)), null)
+
+    val expr3 = AesEncrypt(Literal("ABC".getBytes), Literal("1234567890".getBytes))
+    // a 10-byte (80-bit) key is rejected; only 128, 192 or 256 bits are permitted
+    intercept[java.security.InvalidKeyException] {
+      evaluate(expr3)
+    }
+    intercept[java.security.InvalidKeyException] {
+      UnsafeProjection.create(expr3 :: Nil).apply(null)  // presumably the codegen path; confirm
+    }
+  }
+
+  test("aesDecrypt") {
+    val expr1 = AesDecrypt(UnBase64(Literal("y6Ss+zCYObpCbgfWfyNWTw==")),
+      Literal("1234567890123456".getBytes))
+    val expr2 = AesDecrypt(UnBase64(Literal("BQGHoM3lqYcsurCRq3PlUw==")),
+      Literal("1234567890123456".getBytes))
+
+    checkEvaluation(expr1, "ABC")
+    checkEvaluation(expr2, "")
+
+    // a null input yields null
+    checkEvaluation(AesDecrypt(UnBase64(Literal.create(null, StringType)),
+      Literal("1234567890123456".getBytes)), null)
+    // a null key yields null
+    checkEvaluation(AesDecrypt(UnBase64(Literal("y6Ss+zCYObpCbgfWfyNWTw==")),
+      Literal.create(null, BinaryType)), null)
+    // a null input together with a null key also yields null
+    checkEvaluation(AesDecrypt(UnBase64(Literal.create(null, StringType)),
+      Literal.create(null, BinaryType)), null)
+
+    val expr3 = AesDecrypt(UnBase64(Literal("y6Ss+zCYObpCbgfWfyNWTw==")),
+      Literal("1234567890".getBytes))
+    val expr4 = AesDecrypt(UnBase64(Literal("y6Ss+zCsdYObpCbgfWfyNW3Twewr")),
+      Literal("1234567890123456".getBytes))
+    val expr5 = AesDecrypt(UnBase64(Literal("t6Ss+zCYObpCbgfWfyNWTw==")),
+      Literal("1234567890123456".getBytes))
+
+    // a 10-byte (80-bit) key is rejected; only 128, 192 or 256 bits are permitted
+    intercept[java.security.InvalidKeyException] {
+      evaluate(expr3)
+    }
+    intercept[java.security.InvalidKeyException] {
+      UnsafeProjection.create(expr3 :: Nil).apply(null)
+    }
+    // ciphertext whose length is not a multiple of the AES block size cannot be decrypted
+    intercept[javax.crypto.IllegalBlockSizeException] {
+      evaluate(expr4)
+    }
+    intercept[javax.crypto.IllegalBlockSizeException] {
+      UnsafeProjection.create(expr4 :: Nil).apply(null)
+    }
+    // ciphertext with an altered first character fails the padding check
+    intercept[javax.crypto.BadPaddingException] {
+      evaluate(expr5)
+    }
+    intercept[javax.crypto.BadPaddingException] {
+      UnsafeProjection.create(expr5 :: Nil).apply(null)
+    }
+  }
+
+  ignore("aesEncryptWith256bitsKey") {
+    // Install the Java Cryptography Extension (JCE) Unlimited Strength Jurisdiction Policy Files
+    // before enabling this test; otherwise `java.security.InvalidKeyException` will be thrown,
+    // because the Oracle JDK does not support 192- and 256-bit keys out of the box.
+    checkEvaluation(Base64(AesEncrypt(Literal("ABC".getBytes),
+      Literal("12345678901234561234567890123456".getBytes))), "nYfCuJeRd5eD60yXDw7WEA==")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4f9a6648/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 97c6992..8da50be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1982,6 +1982,42 @@ object functions extends LegacyFunctions {
     new Murmur3Hash(cols.map(_.expr))
   }
 
+  /**
+   * Encrypts input using AES and returns the result as a binary column.
+   * Key lengths of 128, 192 or 256 bits can be used; 192- and 256-bit keys require the Java
+   * Cryptography Extension (JCE) Unlimited Strength Jurisdiction Policy Files to be installed. If
+   * either argument is null, the result is also null. An exception is thrown if the input is
+   * invalid, the key length is not one of the permitted values, or a 192/256-bit key is used
+   * without the JCE policy files installed.
+   *
+   * @param input binary column to encrypt
+   * @param key binary column holding a 128-, 192- or 256-bit key
+   *
+   * @group misc_funcs
+   * @since 2.0.0
+   */
+  def aes_encrypt(input: Column, key: Column): Column = withExpr {
+    AesEncrypt(input.expr, key.expr)
+  }
+
+  /**
+   * Decrypts input using AES and returns the result as a string column.
+   * Key lengths of 128, 192 or 256 bits can be used; 192- and 256-bit keys require the Java
+   * Cryptography Extension (JCE) Unlimited Strength Jurisdiction Policy Files to be installed. If
+   * either argument is null, the result is also null. An exception is thrown if the input is
+   * invalid, the key length is not one of the permitted values, or a 192/256-bit key is used
+   * without the JCE policy files installed.
+   *
+   * @param input binary column to decrypt
+   * @param key binary column holding a 128-, 192- or 256-bit key
+   *
+   * @group misc_funcs
+   * @since 2.0.0
+   */
+  def aes_decrypt(input: Column, key: Column): Column = withExpr {
+    AesDecrypt(input.expr, key.expr)
+  }
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // String functions
   //////////////////////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/4f9a6648/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index aff9efe..0381d57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -206,6 +206,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       Row(2743272264L, 2180413220L))
   }
 
+  test("misc aes encrypt function") {
+    val df = Seq(("ABC", "1234567890123456")).toDF("input", "key")
+    checkAnswer(
+      df.select(base64(aes_encrypt($"input", $"key"))),  // DataFrame function on string columns
+      Row("y6Ss+zCYObpCbgfWfyNWTw==")
+    )
+    checkAnswer(
+      sql("SELECT base64(aes_encrypt('', '1234567890123456'))"),  // SQL function, empty input
+      Row("BQGHoM3lqYcsurCRq3PlUw==")
+    )
+  }
+
+  test("misc aes decrypt function") {
+    val df = Seq(("y6Ss+zCYObpCbgfWfyNWTw==", "1234567890123456")).toDF("input", "key")
+    checkAnswer(
+      df.select((aes_decrypt(unbase64($"input"), $"key"))),  // DataFrame function
+      Row("ABC")
+    )
+    checkAnswer(
+      sql("SELECT aes_decrypt(unbase64('BQGHoM3lqYcsurCRq3PlUw=='), '1234567890123456')"),
+      Row("")  // SQL function; this ciphertext decrypts to the empty string
+    )
+  }
+
   test("string function find_in_set") {
     val df = Seq(("abc,b,ab,c,def", "abc,b,ab,c,def")).toDF("a", "b")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org