You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2023/06/16 12:42:21 UTC

[spark] branch master updated: [SPARK-43290][SQL] Adds support for aes_encrypt IVs and AAD

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fb1ee25a89e [SPARK-43290][SQL] Adds support for aes_encrypt IVs and AAD
fb1ee25a89e is described below

commit fb1ee25a89e8b42178b7f55718859ab5117c2320
Author: Steve Weis <st...@databricks.com>
AuthorDate: Fri Jun 16 15:42:05 2023 +0300

    [SPARK-43290][SQL] Adds support for aes_encrypt IVs and AAD
    
    ### What changes were proposed in this pull request?
    
    This change adds support for user-provided initialization vectors (IVs) and additional authenticated data (AAD) to `aes_encrypt` / `aes_decrypt`. A 12-byte IV may optionally be passed if the mode is "GCM", and a 16-byte IV may be passed if the mode is "CBC". An arbitrary binary value may be passed as additional authenticated data only if "GCM" mode is used.
    
    ### Why are the changes needed?
    
    Callers may wish to provide their own IV values so that the output ciphertext matches a ciphertext generated outside of Spark. AAD is used to bind some input to a ciphertext and ensure that it is presented during decryption -- often used to scope an operation to a specific context.
    
    ### Does this PR introduce _any_ user-facing change?
    
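    Yes, this change introduces two optional parameters (`iv` and `aad`) to `aes_encrypt` and one optional parameter (`aad`) to `aes_decrypt` and `try_aes_decrypt`. Decryption takes no `iv` argument because the IV is read from the beginning of the ciphertext:
    ```
    aes_encrypt(expr, key[, mode[, padding[, iv[, aad]]]])
    aes_decrypt(expr, key[, mode[, padding[, aad]]])
    try_aes_decrypt(expr, key[, mode[, padding[, aad]]])
    ```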
    
    ### How was this patch tested?
    
    ```
    build/sbt "sql/test:testOnly org.apache.spark.sql.DataFrameFunctionsSuite -- -z aes"
    ```
    
    Closes #41488 from sweisdb/SPARK-43290.
    
    Authored-by: Steve Weis <st...@databricks.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../catalyst/expressions/ExpressionImplUtils.java  | 14 +----
 .../spark/sql/catalyst/expressions/misc.scala      | 64 +++++++++++++++++-----
 .../expressions/ExpressionImplUtilsSuite.scala     | 23 +++++++-
 .../sql-functions/sql-expression-schema.md         |  6 +-
 .../apache/spark/sql/DataFrameFunctionsSuite.scala | 50 +++++++++++++++++
 5 files changed, 127 insertions(+), 30 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
index 6aae649718a..a604e6bf225 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
@@ -111,14 +111,6 @@ public class ExpressionImplUtils {
     return checkSum % 10 == 0;
   }
 
-  public static byte[] aesEncrypt(byte[] input, byte[] key, UTF8String mode, UTF8String padding) {
-    return aesEncrypt(input, key, mode, padding, null, null);
-  }
-
-  public static byte[] aesDecrypt(byte[] input, byte[] key, UTF8String mode, UTF8String padding) {
-    return aesDecrypt(input, key, mode, padding, null);
-  }
-
   public static byte[] aesEncrypt(byte[] input,
                                   byte[] key,
                                   UTF8String mode,
@@ -192,7 +184,7 @@ public class ExpressionImplUtils {
       Cipher cipher = Cipher.getInstance(cipherMode.transformation);
       if (opmode == Cipher.ENCRYPT_MODE) {
         // This may be 0-length for ECB
-        if (iv == null) {
+        if (iv == null || iv.length == 0) {
           iv = generateIv(cipherMode);
         } else if (!cipherMode.usesSpec) {
           // If the caller passes an IV, ensure the mode actually uses it.
@@ -210,7 +202,7 @@ public class ExpressionImplUtils {
         }
 
         // If the cipher mode supports additional authenticated data and it is provided, update it
-        if (aad != null) {
+        if (aad != null && aad.length != 0) {
           if (cipherMode.supportsAad != true) {
             throw QueryExecutionErrors.aesUnsupportedAad(mode);
           }
@@ -231,7 +223,7 @@ public class ExpressionImplUtils {
         if (cipherMode.usesSpec) {
           AlgorithmParameterSpec algSpec = getParamSpec(cipherMode, input);
           cipher.init(opmode, secretKey, algSpec);
-          if (aad != null) {
+          if (aad != null && aad.length != 0) {
             if (cipherMode.supportsAad != true) {
               throw QueryExecutionErrors.aesUnsupportedAad(mode);
             }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 67328cde71a..92ed0843521 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -312,8 +312,10 @@ case class CurrentUser() extends LeafExpression with Unevaluable {
 // scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = """
-    _FUNC_(expr, key[, mode[, padding]]) - Returns an encrypted value of `expr` using AES in given `mode` with the specified `padding`.
+    _FUNC_(expr, key[, mode[, padding[, iv[, aad]]]]) - Returns an encrypted value of `expr` using AES in given `mode` with the specified `padding`.
       Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS'), ('GCM', 'NONE') and ('CBC', 'PKCS').
+      Optional initialization vectors (IVs) are only supported for CBC and GCM modes. These must be 16 bytes for CBC and 12 bytes for GCM. If not provided, a random vector will be generated and prepended to the output.
+      Optional additional authenticated data (AAD) is only supported for GCM. If provided for encryption, the identical AAD value must be provided for decryption.
       The default mode is GCM.
   """,
   arguments = """
@@ -324,6 +326,10 @@ case class CurrentUser() extends LeafExpression with Unevaluable {
                Valid modes: ECB, GCM, CBC.
       * padding - Specifies how to pad messages whose length is not a multiple of the block size.
                   Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB, NONE for GCM and PKCS for CBC.
+      * iv - Optional initialization vector. Only supported for CBC and GCM modes.
+             Valid values: None or ''. 16-byte array for CBC mode. 12-byte array for GCM mode.
+      * aad - Optional additional authenticated data. Only supported for GCM mode. This can be any free-form input and
+              must be provided for both encryption and decryption.
   """,
   examples = """
     Examples:
@@ -335,6 +341,10 @@ case class CurrentUser() extends LeafExpression with Unevaluable {
        3lmwu+Mw0H3fi5NDvcu9lg==
       > SELECT base64(_FUNC_('Apache Spark', '1234567890abcdef', 'CBC', 'DEFAULT'));
        2NYmDCjgXTbbxGA3/SnJEfFC/JQ7olk2VQWReIAAFKo=
+      > SELECT base64(_FUNC_('Spark', 'abcdefghijklmnop12345678ABCDEFGH', 'CBC', 'DEFAULT', unhex('00000000000000000000000000000000')));
+       AAAAAAAAAAAAAAAAAAAAAPSd4mWyMZ5mhvjiAPQJnfg=
+      > SELECT base64(_FUNC_('Spark', 'abcdefghijklmnop12345678ABCDEFGH', 'GCM', 'DEFAULT', unhex('000000000000000000000000'), 'This is an AAD mixed into the input'));
+       AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4
   """,
   since = "3.3.0",
   group = "misc_funcs")
@@ -342,16 +352,22 @@ case class AesEncrypt(
     input: Expression,
     key: Expression,
     mode: Expression,
-    padding: Expression)
+    padding: Expression,
+    iv: Expression,
+    aad: Expression)
   extends RuntimeReplaceable with ImplicitCastInputTypes {
 
   override lazy val replacement: Expression = StaticInvoke(
     classOf[ExpressionImplUtils],
     BinaryType,
     "aesEncrypt",
-    Seq(input, key, mode, padding),
+    Seq(input, key, mode, padding, iv, aad),
     inputTypes)
 
+  def this(input: Expression, key: Expression, mode: Expression, padding: Expression, iv: Expression) =
+    this(input, key, mode, padding, iv, Literal(""))
+  def this(input: Expression, key: Expression, mode: Expression, padding: Expression) =
+    this(input, key, mode, padding, Literal(""))
   def this(input: Expression, key: Expression, mode: Expression) =
     this(input, key, mode, Literal("DEFAULT"))
   def this(input: Expression, key: Expression) =
@@ -359,13 +375,14 @@ case class AesEncrypt(
 
   override def prettyName: String = "aes_encrypt"
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, BinaryType, StringType, StringType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType)
 
-  override def children: Seq[Expression] = Seq(input, key, mode, padding)
+  override def children: Seq[Expression] = Seq(input, key, mode, padding, iv, aad)
 
   override protected def withNewChildrenInternal(
       newChildren: IndexedSeq[Expression]): Expression = {
-    copy(newChildren(0), newChildren(1), newChildren(2), newChildren(3))
+    copy(newChildren(0), newChildren(1), newChildren(2), newChildren(3), newChildren(4), newChildren(5))
   }
 }
 
@@ -378,8 +395,9 @@ case class AesEncrypt(
  */
 @ExpressionDescription(
   usage = """
-    _FUNC_(expr, key[, mode[, padding]]) - Returns a decrypted value of `expr` using AES in `mode` with `padding`.
+    _FUNC_(expr, key[, mode[, padding[, aad]]]) - Returns a decrypted value of `expr` using AES in `mode` with `padding`.
       Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS'), ('GCM', 'NONE') and ('CBC', 'PKCS').
+      Optional additional authenticated data (AAD) is only supported for GCM. If provided for encryption, the identical AAD value must be provided for decryption.
       The default mode is GCM.
   """,
   arguments = """
@@ -390,6 +408,8 @@ case class AesEncrypt(
                Valid modes: ECB, GCM, CBC.
       * padding - Specifies how to pad messages whose length is not a multiple of the block size.
                   Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB, NONE for GCM and PKCS for CBC.
+      * aad - Optional additional authenticated data. Only supported for GCM mode. This can be any free-form input and
+              must be provided for both encryption and decryption.
   """,
   examples = """
     Examples:
@@ -401,6 +421,10 @@ case class AesEncrypt(
        Spark SQL
       > SELECT _FUNC_(unbase64('2NYmDCjgXTbbxGA3/SnJEfFC/JQ7olk2VQWReIAAFKo='), '1234567890abcdef', 'CBC');
        Apache Spark
+      > SELECT _FUNC_(unbase64('AAAAAAAAAAAAAAAAAAAAAPSd4mWyMZ5mhvjiAPQJnfg='), 'abcdefghijklmnop12345678ABCDEFGH', 'CBC', 'DEFAULT');
+       Spark
+      > SELECT _FUNC_(unbase64('AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4'), 'abcdefghijklmnop12345678ABCDEFGH', 'GCM', 'DEFAULT', 'This is an AAD mixed into the input');
+       Spark
   """,
   since = "3.3.0",
   group = "misc_funcs")
@@ -408,37 +432,40 @@ case class AesDecrypt(
     input: Expression,
     key: Expression,
     mode: Expression,
-    padding: Expression)
+    padding: Expression,
+    aad: Expression)
   extends RuntimeReplaceable with ImplicitCastInputTypes {
 
   override lazy val replacement: Expression = StaticInvoke(
     classOf[ExpressionImplUtils],
     BinaryType,
     "aesDecrypt",
-    Seq(input, key, mode, padding),
+    Seq(input, key, mode, padding, aad),
     inputTypes)
 
+  def this(input: Expression, key: Expression, mode: Expression, padding: Expression) =
+    this(input, key, mode, padding, Literal(""))
   def this(input: Expression, key: Expression, mode: Expression) =
     this(input, key, mode, Literal("DEFAULT"))
   def this(input: Expression, key: Expression) =
     this(input, key, Literal("GCM"))
 
   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(BinaryType, BinaryType, StringType, StringType)
+    Seq(BinaryType, BinaryType, StringType, StringType, BinaryType)
   }
 
   override def prettyName: String = "aes_decrypt"
 
-  override def children: Seq[Expression] = Seq(input, key, mode, padding)
+  override def children: Seq[Expression] = Seq(input, key, mode, padding, aad)
 
   override protected def withNewChildrenInternal(
       newChildren: IndexedSeq[Expression]): Expression = {
-    copy(newChildren(0), newChildren(1), newChildren(2), newChildren(3))
+    copy(newChildren(0), newChildren(1), newChildren(2), newChildren(3), newChildren(4))
   }
 }
 
 @ExpressionDescription(
-  usage = "_FUNC_(expr, key[, mode[, padding]]) - This is a special version of `aes_decrypt` that performs the same operation, but returns a NULL value instead of raising an error if the decryption cannot be performed.",
+  usage = "_FUNC_(expr, key[, mode[, padding[, aad]]]) - This is a special version of `aes_decrypt` that performs the same operation, but returns a NULL value instead of raising an error if the decryption cannot be performed.",
   examples = """
     Examples:
       > SELECT _FUNC_(unhex('6E7CA17BBB468D3084B5744BCA729FB7B2B7BCB8E4472847D02670489D95FA97DBBA7D3210'), '0000111122223333', 'GCM');
@@ -454,10 +481,17 @@ case class TryAesDecrypt(
     key: Expression,
     mode: Expression,
     padding: Expression,
+    aad: Expression,
     replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules {
 
+  def this(input: Expression,
+           key: Expression,
+           mode: Expression,
+           padding: Expression,
+           aad: Expression) =
+    this(input, key, mode, padding, aad, TryEval(AesDecrypt(input, key, mode, padding, aad)))
   def this(input: Expression, key: Expression, mode: Expression, padding: Expression) =
-    this(input, key, mode, padding, TryEval(AesDecrypt(input, key, mode, padding)))
+    this(input, key, mode, padding, Literal(""))
   def this(input: Expression, key: Expression, mode: Expression) =
     this(input, key, mode, Literal("DEFAULT"))
   def this(input: Expression, key: Expression) =
@@ -465,7 +499,7 @@ case class TryAesDecrypt(
 
   override def prettyName: String = "try_aes_decrypt"
 
-  override def parameters: Seq[Expression] = Seq(input, key, mode, padding)
+  override def parameters: Seq[Expression] = Seq(input, key, mode, padding, aad)
 
   override protected def withNewChildInternal(newChild: Expression): Expression =
     this.copy(replacement = newChild)
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala
index 52258156e31..3b0dd82c173 100644
--- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala
@@ -34,11 +34,16 @@ class ExpressionImplUtilsSuite extends SparkFunSuite {
     aadOpt: Option[String] = None,
     expectedErrorClassOpt: Option[String] = None,
     errorParamsMap: Map[String, String] = Map()) {
+
+    def isIvDefined: Boolean = {
+      ivHexOpt.isDefined && ivHexOpt.get != null && ivHexOpt.get.length > 0
+    }
+
     val plaintextBytes: Array[Byte] = plaintext.getBytes("UTF-8")
     val keyBytes: Array[Byte] = key.getBytes("UTF-8")
     val utf8mode: UTF8String = UTF8String.fromString(mode)
     val utf8Padding: UTF8String = UTF8String.fromString(padding)
-    val deterministic: Boolean = mode.equalsIgnoreCase("ECB") || ivHexOpt.isDefined
+    val deterministic: Boolean = mode.equalsIgnoreCase("ECB") || isIvDefined
     val ivBytes: Array[Byte] =
       ivHexOpt.map({ivHex => Hex.unhex(ivHex.getBytes("UTF-8"))}).getOrElse(null)
     val aadBytes: Array[Byte] = aadOpt.map({aad => aad.getBytes("UTF-8")}).getOrElse(null)
@@ -59,11 +64,27 @@ class ExpressionImplUtilsSuite extends SparkFunSuite {
       "abcdefghijklmnop12345678ABCDEFGH",
       "9J3iZbIxnmaG+OIA9Amd+A==",
       "ECB"),
+    // Test passing non-null, but empty arrays for IV and AAD
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "9J3iZbIxnmaG+OIA9Amd+A==",
+      "ECB",
+      ivHexOpt = Some(""),
+      aadOpt = Some("")),
     TestCase(
       "Spark",
       "abcdefghijklmnop12345678ABCDEFGH",
       "+MgyzJxhusYVGWCljk7fhhl6C6oUqWmtdqoaG93KvhY=",
       "CBC"),
+    // Test passing non-null, but empty arrays for IV and AAD
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "+MgyzJxhusYVGWCljk7fhhl6C6oUqWmtdqoaG93KvhY=",
+      "CBC",
+      ivHexOpt = Some(""),
+      aadOpt = Some("")),
     TestCase(
       "Apache Spark",
       "1234567890abcdef",
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index eef61195357..32c4c02b1b2 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -7,8 +7,8 @@
 | org.apache.spark.sql.catalyst.expressions.Acosh | acosh | SELECT acosh(1) | struct<ACOSH(1):double> |
 | org.apache.spark.sql.catalyst.expressions.Add | + | SELECT 1 + 2 | struct<(1 + 2):int> |
 | org.apache.spark.sql.catalyst.expressions.AddMonths | add_months | SELECT add_months('2016-08-31', 1) | struct<add_months(2016-08-31, 1):date> |
-| org.apache.spark.sql.catalyst.expressions.AesDecrypt | aes_decrypt | SELECT aes_decrypt(unhex('83F16B2AA704794132802D248E6BFD4E380078182D1544813898AC97E709B28A94'), '0000111122223333') | struct<aes_decrypt(unhex(83F16B2AA704794132802D248E6BFD4E380078182D1544813898AC97E709B28A94), 0000111122223333, GCM, DEFAULT):binary> |
-| org.apache.spark.sql.catalyst.expressions.AesEncrypt | aes_encrypt | SELECT hex(aes_encrypt('Spark', '0000111122223333')) | struct<hex(aes_encrypt(Spark, 0000111122223333, GCM, DEFAULT)):string> |
+| org.apache.spark.sql.catalyst.expressions.AesDecrypt | aes_decrypt | SELECT aes_decrypt(unhex('83F16B2AA704794132802D248E6BFD4E380078182D1544813898AC97E709B28A94'), '0000111122223333') | struct<aes_decrypt(unhex(83F16B2AA704794132802D248E6BFD4E380078182D1544813898AC97E709B28A94), 0000111122223333, GCM, DEFAULT, ):binary> |
+| org.apache.spark.sql.catalyst.expressions.AesEncrypt | aes_encrypt | SELECT hex(aes_encrypt('Spark', '0000111122223333')) | struct<hex(aes_encrypt(Spark, 0000111122223333, GCM, DEFAULT, , )):string> |
 | org.apache.spark.sql.catalyst.expressions.And | and | SELECT true and true | struct<(true AND true):boolean> |
 | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct<aggregate(array(1, 2, 3), 0, lambdafunction((namedlambdavariable() + namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable())):int> |
 | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct<reduce(array(1, 2, 3), 0, lambdafunction((namedlambdavariable() + namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable())):int> |
@@ -331,7 +331,7 @@
 | org.apache.spark.sql.catalyst.expressions.TruncDate | trunc | SELECT trunc('2019-08-04', 'week') | struct<trunc(2019-08-04, week):date> |
 | org.apache.spark.sql.catalyst.expressions.TruncTimestamp | date_trunc | SELECT date_trunc('YEAR', '2015-03-05T09:32:05.359') | struct<date_trunc(YEAR, 2015-03-05T09:32:05.359):timestamp> |
 | org.apache.spark.sql.catalyst.expressions.TryAdd | try_add | SELECT try_add(1, 2) | struct<try_add(1, 2):int> |
-| org.apache.spark.sql.catalyst.expressions.TryAesDecrypt | try_aes_decrypt | SELECT try_aes_decrypt(unhex('6E7CA17BBB468D3084B5744BCA729FB7B2B7BCB8E4472847D02670489D95FA97DBBA7D3210'), '0000111122223333', 'GCM') | struct<try_aes_decrypt(unhex(6E7CA17BBB468D3084B5744BCA729FB7B2B7BCB8E4472847D02670489D95FA97DBBA7D3210), 0000111122223333, GCM, DEFAULT):binary> |
+| org.apache.spark.sql.catalyst.expressions.TryAesDecrypt | try_aes_decrypt | SELECT try_aes_decrypt(unhex('6E7CA17BBB468D3084B5744BCA729FB7B2B7BCB8E4472847D02670489D95FA97DBBA7D3210'), '0000111122223333', 'GCM') | struct<try_aes_decrypt(unhex(6E7CA17BBB468D3084B5744BCA729FB7B2B7BCB8E4472847D02670489D95FA97DBBA7D3210), 0000111122223333, GCM, DEFAULT, ):binary> |
 | org.apache.spark.sql.catalyst.expressions.TryDivide | try_divide | SELECT try_divide(3, 2) | struct<try_divide(3, 2):double> |
 | org.apache.spark.sql.catalyst.expressions.TryElementAt | try_element_at | SELECT try_element_at(array(1, 2, 3), 2) | struct<try_element_at(array(1, 2, 3), 2):int> |
 | org.apache.spark.sql.catalyst.expressions.TryMultiply | try_multiply | SELECT try_multiply(2, 3) | struct<try_multiply(2, 3):int> |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 037202de9c9..4d7e8cbb351 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -408,6 +408,56 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  test("aes IV test function") {
+    val key32 = "abcdefghijklmnop12345678ABCDEFGH"
+    val gcmIv = "000000000000000000000000"
+    val encryptedGcm = "AAAAAAAAAAAAAAAAQiYi+sRNYDAOTjdSEcYBFsAWPL1f"
+    val cbcIv = "00000000000000000000000000000000"
+    val encryptedCbc = "AAAAAAAAAAAAAAAAAAAAAPSd4mWyMZ5mhvjiAPQJnfg="
+    val df1 = Seq("Spark").toDF
+    Seq(
+      (key32, encryptedGcm, "GCM", gcmIv),
+      (key32, encryptedCbc, "CBC", cbcIv)).foreach {
+      case (key, ciphertext, mode, iv) =>
+        checkAnswer(
+          df1.selectExpr(s"cast(aes_decrypt(unbase64('$ciphertext'), " +
+              s"'$key', '$mode', 'DEFAULT') as string)"),
+          Seq(Row("Spark")))
+        checkAnswer(
+          df1.selectExpr(s"cast(aes_decrypt(unbase64('$ciphertext'), " +
+            s"binary('$key'), '$mode', 'DEFAULT') as string)"),
+          Seq(Row("Spark")))
+        checkAnswer(
+          df1.selectExpr(
+            s"base64(aes_encrypt(value, '$key32', '$mode', 'DEFAULT', unhex('$iv')))"),
+          Seq(Row(ciphertext)))
+    }
+  }
+
+  test("aes IV and AAD test function") {
+    val key32 = "abcdefghijklmnop12345678ABCDEFGH"
+    val gcmIv = "000000000000000000000000"
+    val aad = "This is an AAD mixed into the input"
+    val encryptedGcm = "AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4"
+    val df1 = Seq("Spark").toDF
+    Seq(
+      (key32, encryptedGcm, "GCM", gcmIv, aad)).foreach {
+      case (key, ciphertext, mode, iv, aad) =>
+        checkAnswer(
+          df1.selectExpr(s"cast(aes_decrypt(unbase64('$ciphertext'), " +
+            s"'$key', '$mode', 'DEFAULT', '$aad') as string)"),
+          Seq(Row("Spark")))
+        checkAnswer(
+          df1.selectExpr(s"cast(aes_decrypt(unbase64('$ciphertext'), " +
+            s"binary('$key'), '$mode', 'DEFAULT', '$aad') as string)"),
+          Seq(Row("Spark")))
+        checkAnswer(
+          df1.selectExpr(
+            s"base64(aes_encrypt(value, '$key32', '$mode', 'DEFAULT', unhex('$iv'), '$aad'))"),
+          Seq(Row(ciphertext)))
+    }
+  }
+
   test("misc aes ECB function") {
     val key16 = "abcdefghijklmnop"
     val key24 = "abcdefghijklmnop12345678"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org