This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ba2d785b994 [SPARK-43290][SQL] Adds AES IV and AAD support to 
ExpressionImplUtils
ba2d785b994 is described below

commit ba2d785b99461871f588de6a8260f3201204f313
Author: Steve Weis <steve.w...@databricks.com>
AuthorDate: Mon May 22 22:43:46 2023 +0300

    [SPARK-43290][SQL] Adds AES IV and AAD support to ExpressionImplUtils
    
    ### What changes were proposed in this pull request?
    This change adds support for optional IV and AAD fields to 
ExpressionImplUtils, which is the underlying library that implements `aes_encrypt` 
and `aes_decrypt`. This allows callers to specify their own initialization 
vector values for specific use cases, and to take advantage of AES-GCM's 
optional additional authenticated data (AAD) input.
    
    This change does **not** yet add this support to the user-facing `aes_encrypt` 
and `aes_decrypt` functions. That will be added in a follow-up rather than in a 
single complex change.
    
    ### Why are the changes needed?
    
    There are some use cases where callers of ExpressionImplUtils (via 
`aes_encrypt`) may want to provide initialization vectors (IVs) or additional 
authenticated data (AAD). The most common cases are:
    1. Ensuring that ciphertext matches values that have been encrypted by 
external tools. In those cases, the caller will need to provide an identical IV 
value.
    2. For AES-CBC mode, some callers want deterministic encrypted output; 
supplying a fixed IV makes the ciphertext identical for the same key and plaintext.
    3. For AES-GCM mode, providing AAD allows callers to bind additional 
data to a ciphertext so that it can only be decrypted by a caller 
providing the same value. This is often used to tie a ciphertext to a specific context.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Not yet. This change adds support to the underlying implementation, but 
does not yet expose the new parameters through the SQL functions.
    
    ### How was this patch tested?
    
    All existing unit tests still pass, and new tests in 
`ExpressionImplUtilsSuite` exercise the new code paths:
    ```
    build/sbt "sql/test:testOnly org.apache.spark.sql.DataFrameFunctionsSuite"
    build/sbt "catalyst/test:testOnly 
org.apache.spark.sql.catalyst.expressions.ExpressionImplUtilsSuite"
    ```
    
    Closes #40970 from sweisdb/SPARK-43290.
    
    Lead-authored-by: Steve Weis <steve.w...@databricks.com>
    Co-authored-by: sweisdb <60895808+swei...@users.noreply.github.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 core/src/main/resources/error/error-classes.json   |  17 +-
 .../catalyst/expressions/ExpressionImplUtils.java  |  98 ++++++--
 .../spark/sql/errors/QueryExecutionErrors.scala    |  28 ++-
 .../expressions/ExpressionImplUtilsSuite.scala     | 268 ++++++++++++++++++---
 .../sql/errors/QueryExecutionErrorsSuite.scala     |   4 +-
 5 files changed, 368 insertions(+), 47 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json 
b/core/src/main/resources/error/error-classes.json
index b5b33758341..b3023fad83b 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1074,11 +1074,16 @@
       "The value of parameter(s) <parameter> in <functionName> is invalid:"
     ],
     "subClass" : {
-      "AES_KEY" : {
+      "AES_CRYPTO_ERROR" : {
         "message" : [
           "detail message: <detailMessage>"
         ]
       },
+      "AES_IV_LENGTH" : {
+        "message" : [
+          "supports 16-byte CBC IVs and 12-byte GCM IVs, but got 
<actualLength> bytes for <mode>."
+        ]
+      },
       "AES_KEY_LENGTH" : {
         "message" : [
           "expects a binary value with 16, 24 or 32 bytes, but got 
<actualLength> bytes."
@@ -1839,6 +1844,16 @@
           "AES-<mode> with the padding <padding> by the <functionName> 
function."
         ]
       },
+      "AES_MODE_AAD" : {
+        "message" : [
+          "<functionName> with AES-<mode> does not support additional 
authenticate data (AAD)."
+        ]
+      },
+      "AES_MODE_IV" : {
+        "message" : [
+          "<functionName> with AES-<mode> does not support initialization 
vectors (IVs)."
+        ]
+      },
       "ANALYZE_UNCACHED_TEMP_VIEW" : {
         "message" : [
           "The ANALYZE TABLE FOR COLUMNS command can operate on temporary 
views that have been cached already. Consider to cache the view <viewName>."
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
index 6843a348006..6aae649718a 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
@@ -42,33 +42,40 @@ public class ExpressionImplUtils {
   private static final int CBC_IV_LEN = 16;
 
   enum CipherMode {
-    ECB("ECB", 0, 0, "AES/ECB/PKCS5Padding", false),
-    CBC("CBC", CBC_IV_LEN, 0, "AES/CBC/PKCS5Padding", true),
-    GCM("GCM", GCM_IV_LEN, GCM_TAG_LEN, "AES/GCM/NoPadding", true);
+    ECB("ECB", 0, 0, "AES/ECB/PKCS5Padding", false, false),
+    CBC("CBC", CBC_IV_LEN, 0, "AES/CBC/PKCS5Padding", true, false),
+    GCM("GCM", GCM_IV_LEN, GCM_TAG_LEN, "AES/GCM/NoPadding", true, true);
 
     private final String name;
     final int ivLength;
     final int tagLength;
     final String transformation;
     final boolean usesSpec;
-
-    CipherMode(String name, int ivLen, int tagLen, String transformation, 
boolean usesSpec) {
+    final boolean supportsAad;
+
+    CipherMode(String name,
+               int ivLen,
+               int tagLen,
+               String transformation,
+               boolean usesSpec,
+               boolean supportsAad) {
       this.name = name;
       this.ivLength = ivLen;
       this.tagLength = tagLen;
       this.transformation = transformation;
       this.usesSpec = usesSpec;
+      this.supportsAad = supportsAad;
     }
 
     static CipherMode fromString(String modeName, String padding) {
-      if (modeName.equalsIgnoreCase(ECB.name) &&
-              (padding.equalsIgnoreCase("PKCS") || 
padding.equalsIgnoreCase("DEFAULT"))) {
+      boolean isNone = padding.equalsIgnoreCase("NONE");
+      boolean isPkcs = padding.equalsIgnoreCase("PKCS");
+      boolean isDefault = padding.equalsIgnoreCase("DEFAULT");
+      if (modeName.equalsIgnoreCase(ECB.name) && (isPkcs || isDefault)) {
         return ECB;
-      } else if (modeName.equalsIgnoreCase(CBC.name) &&
-              (padding.equalsIgnoreCase("PKCS") || 
padding.equalsIgnoreCase("DEFAULT"))) {
+      } else if (modeName.equalsIgnoreCase(CBC.name) && (isPkcs || isDefault)) 
{
         return CBC;
-      } else if (modeName.equalsIgnoreCase(GCM.name) &&
-              (padding.equalsIgnoreCase("NONE") || 
padding.equalsIgnoreCase("DEFAULT"))) {
+      } else if (modeName.equalsIgnoreCase(GCM.name) && (isNone || isDefault)) 
{
         return GCM;
       }
       throw QueryExecutionErrors.aesModeUnsupportedError(modeName, padding);
@@ -105,11 +112,44 @@ public class ExpressionImplUtils {
   }
 
   public static byte[] aesEncrypt(byte[] input, byte[] key, UTF8String mode, 
UTF8String padding) {
-    return aesInternal(input, key, mode.toString(), padding.toString(), 
Cipher.ENCRYPT_MODE);
+    return aesEncrypt(input, key, mode, padding, null, null);
   }
 
   public static byte[] aesDecrypt(byte[] input, byte[] key, UTF8String mode, 
UTF8String padding) {
-    return aesInternal(input, key, mode.toString(), padding.toString(), 
Cipher.DECRYPT_MODE);
+    return aesDecrypt(input, key, mode, padding, null);
+  }
+
+  public static byte[] aesEncrypt(byte[] input,
+                                  byte[] key,
+                                  UTF8String mode,
+                                  UTF8String padding,
+                                  byte[] iv,
+                                  byte[] aad) {
+    return aesInternal(
+            input,
+            key,
+            mode.toString(),
+            padding.toString(),
+            Cipher.ENCRYPT_MODE,
+            iv,
+            aad
+    );
+  }
+
+  public static byte[] aesDecrypt(byte[] input,
+                                  byte[] key,
+                                  UTF8String mode,
+                                  UTF8String padding,
+                                  byte[] aad) {
+    return aesInternal(
+            input,
+            key,
+            mode.toString(),
+            padding.toString(),
+            Cipher.DECRYPT_MODE,
+            null,
+            aad
+    );
   }
 
   private static SecretKeySpec getSecretKeySpec(byte[] key) {
@@ -143,20 +183,40 @@ public class ExpressionImplUtils {
       byte[] key,
       String mode,
       String padding,
-      int opmode) {
+      int opmode,
+      byte[] iv,
+      byte[] aad) {
     try {
       SecretKeySpec secretKey = getSecretKeySpec(key);
       CipherMode cipherMode = CipherMode.fromString(mode, padding);
       Cipher cipher = Cipher.getInstance(cipherMode.transformation);
       if (opmode == Cipher.ENCRYPT_MODE) {
-        // This IV will be 0-length for ECB
-        byte[] iv = generateIv(cipherMode);
+        // This may be 0-length for ECB
+        if (iv == null) {
+          iv = generateIv(cipherMode);
+        } else if (!cipherMode.usesSpec) {
+          // If the caller passes an IV, ensure the mode actually uses it.
+          throw QueryExecutionErrors.aesUnsupportedIv(mode);
+        }
+        if (iv.length != cipherMode.ivLength) {
+          throw QueryExecutionErrors.invalidAesIvLengthError(mode, iv.length);
+        }
+
         if (cipherMode.usesSpec) {
           AlgorithmParameterSpec algSpec = getParamSpec(cipherMode, iv);
           cipher.init(opmode, secretKey, algSpec);
         } else {
           cipher.init(opmode, secretKey);
         }
+
+        // If the cipher mode supports additional authenticated data and it is 
provided, update it
+        if (aad != null) {
+          if (cipherMode.supportsAad != true) {
+            throw QueryExecutionErrors.aesUnsupportedAad(mode);
+          }
+          cipher.updateAAD(aad);
+        }
+
         byte[] encrypted = cipher.doFinal(input, 0, input.length);
         if (iv.length > 0) {
           ByteBuffer byteBuffer = ByteBuffer.allocate(iv.length + 
encrypted.length);
@@ -171,6 +231,12 @@ public class ExpressionImplUtils {
         if (cipherMode.usesSpec) {
           AlgorithmParameterSpec algSpec = getParamSpec(cipherMode, input);
           cipher.init(opmode, secretKey, algSpec);
+          if (aad != null) {
+            if (cipherMode.supportsAad != true) {
+              throw QueryExecutionErrors.aesUnsupportedAad(mode);
+            }
+            cipher.updateAAD(aad);
+          }
           return cipher.doFinal(input, cipherMode.ivLength, input.length - 
cipherMode.ivLength);
         } else {
           cipher.init(opmode, secretKey);
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 67c5fa54732..5daa8ed3b7f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -2655,13 +2655,39 @@ private[sql] object QueryExecutionErrors extends 
QueryErrorsBase {
 
   def aesCryptoError(detailMessage: String): RuntimeException = {
     new SparkRuntimeException(
-      errorClass = "INVALID_PARAMETER_VALUE.AES_KEY",
+      errorClass = "INVALID_PARAMETER_VALUE.AES_CRYPTO_ERROR",
       messageParameters = Map(
         "parameter" -> (toSQLId("expr") + ", " + toSQLId("key")),
         "functionName" -> aesFuncName,
         "detailMessage" -> detailMessage))
   }
 
+  def invalidAesIvLengthError(mode: String, actualLength: Int): 
RuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "INVALID_PARAMETER_VALUE.AES_IV_LENGTH",
+      messageParameters = Map(
+        "mode" -> mode,
+        "parameter" -> toSQLId("iv"),
+        "functionName" -> aesFuncName,
+        "actualLength" -> actualLength.toString()))
+  }
+
+  def aesUnsupportedIv(mode: String): RuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "UNSUPPORTED_FEATURE.AES_MODE_IV",
+      messageParameters = Map(
+        "mode" -> mode,
+        "functionName" -> toSQLId("aes_encrypt")))
+  }
+
+  def aesUnsupportedAad(mode: String): RuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "UNSUPPORTED_FEATURE.AES_MODE_AAD",
+      messageParameters = Map(
+        "mode" -> mode,
+        "functionName" -> toSQLId("aes_encrypt")))
+  }
+
   def hiveTableWithAnsiIntervalsError(tableName: String): 
SparkUnsupportedOperationException = {
     new SparkUnsupportedOperationException(
       errorClass = "_LEGACY_ERROR_TEMP_2276",
diff --git 
a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala
 
b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala
index 2c17ec8dbe7..52258156e31 100644
--- 
a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala
+++ 
b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala
@@ -17,21 +17,31 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
 import org.apache.spark.unsafe.types.UTF8String
 
 class ExpressionImplUtilsSuite extends SparkFunSuite {
+  private val b64decoder = java.util.Base64.getDecoder
+  private val b64encoder = java.util.Base64.getEncoder
+
   case class TestCase(
     plaintext: String,
     key: String,
     base64CiphertextExpected: String,
     mode: String,
-    padding: String = "Default") {
-    val plaintextBytes = plaintext.getBytes("UTF-8")
-    val keyBytes = key.getBytes("UTF-8")
-    val utf8mode = UTF8String.fromString(mode)
-    val utf8Padding = UTF8String.fromString(padding)
-    val deterministic = mode.equalsIgnoreCase("ECB")
+    padding: String = "Default",
+    ivHexOpt: Option[String] = None,
+    aadOpt: Option[String] = None,
+    expectedErrorClassOpt: Option[String] = None,
+    errorParamsMap: Map[String, String] = Map()) {
+    val plaintextBytes: Array[Byte] = plaintext.getBytes("UTF-8")
+    val keyBytes: Array[Byte] = key.getBytes("UTF-8")
+    val utf8mode: UTF8String = UTF8String.fromString(mode)
+    val utf8Padding: UTF8String = UTF8String.fromString(padding)
+    val deterministic: Boolean = mode.equalsIgnoreCase("ECB") || 
ivHexOpt.isDefined
+    val ivBytes: Array[Byte] =
+      ivHexOpt.map({ivHex => 
Hex.unhex(ivHex.getBytes("UTF-8"))}).getOrElse(null)
+    val aadBytes: Array[Byte] = aadOpt.map({aad => 
aad.getBytes("UTF-8")}).getOrElse(null)
   }
 
   val testCases = Seq(
@@ -85,29 +95,233 @@ class ExpressionImplUtilsSuite extends SparkFunSuite {
   )
 
   test("AesDecrypt Only") {
-    val decoder = java.util.Base64.getDecoder
-    testCases.foreach { t =>
-      val expectedBytes = decoder.decode(t.base64CiphertextExpected)
-      val decryptedBytes =
-        ExpressionImplUtils.aesDecrypt(expectedBytes, t.keyBytes, t.utf8mode, 
t.utf8Padding)
-      val decryptedString = new String(decryptedBytes)
-      assert(decryptedString == t.plaintext)
-    }
+    testCases.map(decOnlyCase)
   }
 
   test("AesEncrypt and AesDecrypt") {
-    val encoder = java.util.Base64.getEncoder
-    testCases.foreach { t =>
-      val ciphertextBytes =
-        ExpressionImplUtils.aesEncrypt(t.plaintextBytes, t.keyBytes, 
t.utf8mode, t.utf8Padding)
-      val ciphertextBase64 = encoder.encodeToString(ciphertextBytes)
-      val decryptedBytes =
-        ExpressionImplUtils.aesDecrypt(ciphertextBytes, t.keyBytes, 
t.utf8mode, t.utf8Padding)
-      val decryptedString = new String(decryptedBytes)
-      assert(decryptedString == t.plaintext)
-      if (t.deterministic) {
-        assert(t.base64CiphertextExpected == ciphertextBase64)
-      }
+    testCases.map(encDecCase)
+  }
+
+  val ivAadTestCases = Seq(
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "AAAAAAAAAAAAAAAAAAAAAPSd4mWyMZ5mhvjiAPQJnfg=",
+      "CBC",
+      ivHexOpt = Some("00000000000000000000000000000000")),
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "AAAAAAAAAAAAAAAAQiYi+sRNYDAOTjdSEcYBFsAWPL1f",
+      "GCM",
+      ivHexOpt = Some("000000000000000000000000")
+    ),
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4",
+      "GCM",
+      ivHexOpt = Some("000000000000000000000000"),
+      aadOpt = Some("This is an AAD mixed into the input")
+    ),
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4",
+      "GCM",
+      aadOpt = Some("This is an AAD mixed into the input")
+    )
+  )
+
+  test("AesDecrypt only with IVs or AADs") {
+    ivAadTestCases.map(decOnlyCase)
+  }
+
+  test("AesEncrypt and AesDecrypt with IVs or AADs") {
+    ivAadTestCases.map(encDecCase)
+  }
+
+  def decOnlyCase(t: TestCase): Unit = {
+    val expectedBytes = b64decoder.decode(t.base64CiphertextExpected)
+    val decryptedBytes = ExpressionImplUtils.aesDecrypt(
+      expectedBytes,
+      t.keyBytes,
+      t.utf8mode,
+      t.utf8Padding,
+      t.aadBytes
+    )
+    val decryptedString = new String(decryptedBytes)
+    assert(decryptedString == t.plaintext)
+  }
+
+  def encDecCase(t: TestCase): Unit = {
+    val ciphertextBytes = ExpressionImplUtils.aesEncrypt(
+      t.plaintextBytes,
+      t.keyBytes,
+      t.utf8mode,
+      t.utf8Padding,
+      t.ivBytes,
+      t.aadBytes
+    )
+    val ciphertextBase64 = b64encoder.encodeToString(ciphertextBytes)
+    val decryptedBytes = ExpressionImplUtils.aesDecrypt(
+      ciphertextBytes,
+      t.keyBytes,
+      t.utf8mode,
+      t.utf8Padding,
+      t.aadBytes
+    )
+    val decryptedString = new String(decryptedBytes)
+    assert(decryptedString == t.plaintext)
+    if (t.deterministic) {
+      assert(t.base64CiphertextExpected == ciphertextBase64)
+    }
+  }
+
+  val unsupportedErrorCases = Seq(
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "unused",
+      "ECB",
+      ivHexOpt = Some("0000000000000000"),
+      expectedErrorClassOpt = Some("UNSUPPORTED_FEATURE.AES_MODE_IV"),
+      errorParamsMap = Map(
+        "mode" -> "ECB",
+        "functionName" -> "`aes_encrypt`"
+      )
+    ),
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "unused",
+      "ECB",
+      aadOpt = Some("ECB does not support AAD mode"),
+      expectedErrorClassOpt = Some("UNSUPPORTED_FEATURE.AES_MODE_AAD"),
+      errorParamsMap = Map(
+        "mode" -> "ECB",
+        "functionName" -> "`aes_encrypt`"
+      )
+    ),
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "unused",
+      "CBC",
+      ivHexOpt = Some("0000000000"),
+      expectedErrorClassOpt = Some("INVALID_PARAMETER_VALUE.AES_IV_LENGTH"),
+      errorParamsMap = Map(
+        "mode" -> "CBC",
+        "parameter" -> "`iv`",
+        "functionName" -> "`aes_encrypt`/`aes_decrypt`",
+        "actualLength" -> "5"
+      )
+    ),
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "unused",
+      "GCM",
+      ivHexOpt = Some("0000000000"),
+      expectedErrorClassOpt = Some("INVALID_PARAMETER_VALUE.AES_IV_LENGTH"),
+      errorParamsMap = Map(
+        "mode" -> "GCM",
+        "parameter" -> "`iv`",
+        "functionName" -> "`aes_encrypt`/`aes_decrypt`",
+        "actualLength" -> "5"
+      )
+    ),
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "unused",
+      "GCM",
+      padding = "PKCS",
+      expectedErrorClassOpt = Some("UNSUPPORTED_FEATURE.AES_MODE"),
+      errorParamsMap = Map(
+        "mode" -> "GCM",
+        "padding" -> "PKCS",
+        "functionName" -> "`aes_encrypt`/`aes_decrypt`"
+      )
+    ),
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "unused",
+      "CBC",
+      aadOpt = Some("CBC doesn't support AADs"),
+      expectedErrorClassOpt = Some("UNSUPPORTED_FEATURE.AES_MODE_AAD"),
+      errorParamsMap = Map(
+        "mode" -> "CBC",
+        "functionName" -> "`aes_encrypt`"
+      )
+    )
+  )
+
+  test("AesEncrypt unsupported errors") {
+    unsupportedErrorCases.foreach { t =>
+      checkExpectedError(t, encDecCase)
+    }
+  }
+
+  val corruptedCiphertexts = Seq(
+    // This is truncated
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "+MgyzJxhusYVGWCljk7fhhl6C6oUqWmtdqoaG93=",
+      "CBC",
+      expectedErrorClassOpt = Some("INVALID_PARAMETER_VALUE.AES_CRYPTO_ERROR"),
+      errorParamsMap = Map(
+        "parameter" -> "`expr`, `key`",
+        "functionName" -> "`aes_encrypt`/`aes_decrypt`",
+        "detailMessage" ->
+          "Input length must be multiple of 16 when decrypting with padded 
cipher"
+      )
+    ),
+    // The ciphertext is corrupted
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "y5la3muiuxN2suj6VsYXB+1XUFjtrUD0/zv5eDafsA3U",
+      "GCM",
+      expectedErrorClassOpt = Some("INVALID_PARAMETER_VALUE.AES_CRYPTO_ERROR"),
+      errorParamsMap = Map(
+        "parameter" -> "`expr`, `key`",
+        "functionName" -> "`aes_encrypt`/`aes_decrypt`",
+        "detailMessage" -> "Tag mismatch!"
+      )
+    ),
+    // Valid ciphertext, wrong AAD
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4",
+      "GCM",
+      aadOpt = Some("The ciphertext is valid, but the AAD is wrong"),
+      expectedErrorClassOpt = Some("INVALID_PARAMETER_VALUE.AES_CRYPTO_ERROR"),
+      errorParamsMap = Map(
+        "parameter" -> "`expr`, `key`",
+        "functionName" -> "`aes_encrypt`/`aes_decrypt`",
+        "detailMessage" -> "Tag mismatch!"
+      )
+    )
+  )
+
+  test("AesEncrypt Expected Errors") {
+    corruptedCiphertexts.foreach { t =>
+      checkExpectedError(t, decOnlyCase)
     }
   }
+
+
+  private def checkExpectedError(t: TestCase, f: TestCase => Unit) = {
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        f(t)
+      },
+      errorClass = t.expectedErrorClassOpt.get,
+      parameters = t.errorParamsMap
+    )
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index 74c2ae71b3a..4bfab92ccb1 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -121,7 +121,7 @@ class QueryExecutionErrorsSuite
     }
   }
 
-  test("INVALID_PARAMETER_VALUE.AES_KEY: AES decrypt failure - key mismatch") {
+  test("INVALID_PARAMETER_VALUE.AES_CRYPTO_ERROR: AES decrypt failure - key 
mismatch") {
     val (_, df2) = getAesInputs()
     Seq(
       ("value16", "1234567812345678"),
@@ -131,7 +131,7 @@ class QueryExecutionErrorsSuite
         exception = intercept[SparkException] {
           df2.selectExpr(s"aes_decrypt(unbase64($colName), binary('$key'), 
'ECB')").collect
         }.getCause.asInstanceOf[SparkRuntimeException],
-        errorClass = "INVALID_PARAMETER_VALUE.AES_KEY",
+        errorClass = "INVALID_PARAMETER_VALUE.AES_CRYPTO_ERROR",
         parameters = Map("parameter" -> "`expr`, `key`",
           "functionName" -> "`aes_encrypt`/`aes_decrypt`",
           "detailMessage" -> ("Given final block not properly padded. " +


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to