This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2643223e21b [SPARK-43286][SQL] Updates aes_encrypt CBC mode to generate random IVs
2643223e21b is described below

commit 2643223e21b6e80ea150b41a99c040ef7eebd51a
Author: Steve Weis <steve.w...@databricks.com>
AuthorDate: Tue May 9 11:56:28 2023 +0300

    [SPARK-43286][SQL] Updates aes_encrypt CBC mode to generate random IVs
    
    ### What changes were proposed in this pull request?
    
    The current implementation of AES-CBC mode, called via `aes_encrypt` and
    `aes_decrypt`, uses a key derivation function (KDF) based on OpenSSL's
    [EVP_BytesToKey](https://www.openssl.org/docs/man3.0/man3/EVP_BytesToKey.html).
    That KDF is intended for deriving keys from passwords, and OpenSSL's own
    documentation discourages its use: "Newer applications should use a more
    modern algorithm".
    
    `aes_encrypt` and `aes_decrypt` should use the key directly in CBC mode,
    as they already do for GCM and ECB modes. The output should then be the
    initialization vector (IV) prepended to the ciphertext, as is done with
    GCM mode:
    `[16-byte randomly generated IV | AES-CBC encrypted ciphertext]`
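
    As a rough sketch (not the exact Spark code path, and the helper names are
    made up for illustration), the layout above corresponds to the following
    use of the standard `javax.crypto` API: encryption prepends a freshly
    generated 16-byte IV, and decryption reads the IV back from the first 16
    bytes of its input.
    ```
    import java.security.SecureRandom
    import javax.crypto.Cipher
    import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}

    // Produce [16-byte random IV | AES-CBC ciphertext].
    def encryptCbc(plaintext: Array[Byte], key: Array[Byte]): Array[Byte] = {
      val iv = new Array[Byte](16)
      new SecureRandom().nextBytes(iv)  // fresh random IV for every call
      val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
      cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
      iv ++ cipher.doFinal(plaintext)   // prepend the IV to the ciphertext
    }

    // Read the IV from the first 16 bytes and decrypt the remainder.
    def decryptCbc(input: Array[Byte], key: Array[Byte]): Array[Byte] = {
      val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
      cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(key, "AES"),
        new IvParameterSpec(input, 0, 16))
      cipher.doFinal(input, 16, input.length - 16)
    }
    ```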
    
    ### Why are the changes needed?
    
    We want the ciphertext output format to be consistent across modes.
    OpenSSL's EVP_BytesToKey is effectively deprecated, and OpenSSL's own
    documentation says not to use it. Instead, CBC mode will now use the key
    directly and generate a random initialization vector.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. AES-CBC output produced in the previous format is incompatible with
    this change. The previous format landed only recently, and we want to land
    this change before CBC mode is used in practice.
    
    ### How was this patch tested?
    
    A new unit test in `DataFrameFunctionsSuite` was added to test both GCM and
    CBC modes. Also, a new standalone unit test suite was added in
    `ExpressionImplUtilsSuite` to test all the modes and various key lengths.
    ```
    build/sbt "sql/test:testOnly org.apache.spark.sql.DataFrameFunctionsSuite"
    build/sbt "sql/test:testOnly 
org.apache.spark.sql.catalyst.expressions.ExpressionImplUtilsSuite"
    ```
    
    CBC values can be verified with `openssl enc` using the following command:
    ```
    echo -n "[INPUT]" | openssl enc -a -e -aes-256-cbc -iv [HEX IV] -K [HEX KEY]
    echo -n "Spark" | openssl enc -a -e -aes-256-cbc -iv 
f8c832cc9c61bac6151960a58e4edf86 -K 
6162636465666768696a6b6c6d6e6f7031323334353637384142434445464748
    ```
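
    As a purely illustrative aside (this snippet is not part of the patch and
    the variable names are made up), the `-iv` and `-K` hex values above can be
    derived from a CBC `aes_encrypt` output and its key, since the first 16
    bytes of the output are the IV:
    ```
    // The base64 value is the 32-byte-key CBC vector from the new test suite.
    val bytes  = java.util.Base64.getDecoder.decode("+MgyzJxhusYVGWCljk7fhhl6C6oUqWmtdqoaG93KvhY=")
    val ivHex  = bytes.take(16).map("%02x".format(_)).mkString  // value for -iv
    val keyHex = "abcdefghijklmnop12345678ABCDEFGH".getBytes("UTF-8")
      .map("%02x".format(_)).mkString                           // value for -K
    // Encrypting "Spark" with openssl and these parameters should reproduce
    // the base64 encoding of bytes.drop(16), i.e. the ciphertext without the IV.
    ```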
    
    Closes #40969 from sweisdb/SPARK-43286.
    
    Authored-by: Steve Weis <steve.w...@databricks.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 core/src/main/resources/error/error-classes.json   |   5 -
 .../catalyst/expressions/ExpressionImplUtils.java  | 187 ++++++++++-----------
 .../spark/sql/catalyst/expressions/misc.scala      |   4 +-
 .../spark/sql/errors/QueryExecutionErrors.scala    |   9 -
 .../apache/spark/sql/DataFrameFunctionsSuite.scala |  36 +++-
 .../expressions/ExpressionImplUtilsSuite.scala     | 113 +++++++++++++
 .../sql/errors/QueryExecutionErrorsSuite.scala     |  19 ---
 7 files changed, 232 insertions(+), 141 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 8b0d98c7e3d..dc97a735b39 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1016,11 +1016,6 @@
           "expects a binary value with 16, 24 or 32 bytes, but got 
<actualLength> bytes."
         ]
       },
-      "AES_SALTED_MAGIC" : {
-        "message" : [
-          "Initial bytes from input <saltedMagic> do not match 'Salted__' (0x53616C7465645F5F)."
-        ]
-      },
       "PATTERN" : {
         "message" : [
           "<value>."
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
index 680ad11ad73..6843a348006 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
@@ -26,27 +26,54 @@ import javax.crypto.spec.IvParameterSpec;
 import javax.crypto.spec.SecretKeySpec;
 import java.nio.ByteBuffer;
 import java.security.GeneralSecurityException;
-import java.security.MessageDigest;
-import java.security.NoSuchAlgorithmException;
 import java.security.SecureRandom;
-import java.util.Arrays;
+import java.security.spec.AlgorithmParameterSpec;
 
-import static java.nio.charset.StandardCharsets.US_ASCII;
 
 /**
- * An utility class for constructing expressions.
+ * A utility class for constructing expressions.
  */
 public class ExpressionImplUtils {
-  private static final SecureRandom secureRandom = new SecureRandom();
+  private static final ThreadLocal<SecureRandom> threadLocalSecureRandom =
+          ThreadLocal.withInitial(SecureRandom::new);
+
   private static final int GCM_IV_LEN = 12;
   private static final int GCM_TAG_LEN = 128;
-
   private static final int CBC_IV_LEN = 16;
-  private static final int CBC_SALT_LEN = 8;
-  /** OpenSSL's magic initial bytes. */
-  private static final String SALTED_STR = "Salted__";
-  private static final byte[] SALTED_MAGIC = SALTED_STR.getBytes(US_ASCII);
 
+  enum CipherMode {
+    ECB("ECB", 0, 0, "AES/ECB/PKCS5Padding", false),
+    CBC("CBC", CBC_IV_LEN, 0, "AES/CBC/PKCS5Padding", true),
+    GCM("GCM", GCM_IV_LEN, GCM_TAG_LEN, "AES/GCM/NoPadding", true);
+
+    private final String name;
+    final int ivLength;
+    final int tagLength;
+    final String transformation;
+    final boolean usesSpec;
+
+    CipherMode(String name, int ivLen, int tagLen, String transformation, boolean usesSpec) {
+      this.name = name;
+      this.ivLength = ivLen;
+      this.tagLength = tagLen;
+      this.transformation = transformation;
+      this.usesSpec = usesSpec;
+    }
+
+    static CipherMode fromString(String modeName, String padding) {
+      if (modeName.equalsIgnoreCase(ECB.name) &&
+              (padding.equalsIgnoreCase("PKCS") || padding.equalsIgnoreCase("DEFAULT"))) {
+        return ECB;
+      } else if (modeName.equalsIgnoreCase(CBC.name) &&
+              (padding.equalsIgnoreCase("PKCS") || padding.equalsIgnoreCase("DEFAULT"))) {
+        return CBC;
+      } else if (modeName.equalsIgnoreCase(GCM.name) &&
+              (padding.equalsIgnoreCase("NONE") || padding.equalsIgnoreCase("DEFAULT"))) {
+        return GCM;
+      }
+      throw QueryExecutionErrors.aesModeUnsupportedError(modeName, padding);
+    }
+  }
 
   /**
    * Function to check if a given number string is a valid Luhn number
@@ -85,113 +112,73 @@ public class ExpressionImplUtils {
     return aesInternal(input, key, mode.toString(), padding.toString(), Cipher.DECRYPT_MODE);
   }
 
+  private static SecretKeySpec getSecretKeySpec(byte[] key) {
+    switch (key.length) {
+      case 16: case 24: case 32:
+        return new SecretKeySpec(key, 0, key.length, "AES");
+      default:
+        throw QueryExecutionErrors.invalidAesKeyLengthError(key.length);
+    }
+  }
+
+  private static byte[] generateIv(CipherMode mode) {
+    byte[] iv = new byte[mode.ivLength];
+    threadLocalSecureRandom.get().nextBytes(iv);
+    return iv;
+  }
+
+  private static AlgorithmParameterSpec getParamSpec(CipherMode mode, byte[] input) {
+    switch (mode) {
+      case CBC:
+        return new IvParameterSpec(input, 0, mode.ivLength);
+      case GCM:
+        return new GCMParameterSpec(mode.tagLength, input, 0, mode.ivLength);
+      default:
+        return null;
+    }
+  }
+
   private static byte[] aesInternal(
       byte[] input,
       byte[] key,
       String mode,
       String padding,
       int opmode) {
-    SecretKeySpec secretKey;
-
-    switch (key.length) {
-      case 16:
-      case 24:
-      case 32:
-        secretKey = new SecretKeySpec(key, 0, key.length, "AES");
-        break;
-      default:
-        throw QueryExecutionErrors.invalidAesKeyLengthError(key.length);
-      }
-
     try {
-      if (mode.equalsIgnoreCase("ECB") &&
-          (padding.equalsIgnoreCase("PKCS") || padding.equalsIgnoreCase("DEFAULT"))) {
-        Cipher cipher = Cipher.getInstance("AES/ECB/PKCS5Padding");
-        cipher.init(opmode, secretKey);
-        return cipher.doFinal(input, 0, input.length);
-      } else if (mode.equalsIgnoreCase("GCM") &&
-          (padding.equalsIgnoreCase("NONE") || padding.equalsIgnoreCase("DEFAULT"))) {
-        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
-        if (opmode == Cipher.ENCRYPT_MODE) {
-          byte[] iv = new byte[GCM_IV_LEN];
-          secureRandom.nextBytes(iv);
-          GCMParameterSpec parameterSpec = new GCMParameterSpec(GCM_TAG_LEN, iv);
-          cipher.init(Cipher.ENCRYPT_MODE, secretKey, parameterSpec);
-          byte[] encrypted = cipher.doFinal(input, 0, input.length);
+      SecretKeySpec secretKey = getSecretKeySpec(key);
+      CipherMode cipherMode = CipherMode.fromString(mode, padding);
+      Cipher cipher = Cipher.getInstance(cipherMode.transformation);
+      if (opmode == Cipher.ENCRYPT_MODE) {
+        // This IV will be 0-length for ECB
+        byte[] iv = generateIv(cipherMode);
+        if (cipherMode.usesSpec) {
+          AlgorithmParameterSpec algSpec = getParamSpec(cipherMode, iv);
+          cipher.init(opmode, secretKey, algSpec);
+        } else {
+          cipher.init(opmode, secretKey);
+        }
+        byte[] encrypted = cipher.doFinal(input, 0, input.length);
+        if (iv.length > 0) {
           ByteBuffer byteBuffer = ByteBuffer.allocate(iv.length + encrypted.length);
           byteBuffer.put(iv);
           byteBuffer.put(encrypted);
           return byteBuffer.array();
         } else {
-          assert(opmode == Cipher.DECRYPT_MODE);
-          GCMParameterSpec parameterSpec = new GCMParameterSpec(GCM_TAG_LEN, input, 0, GCM_IV_LEN);
-          cipher.init(Cipher.DECRYPT_MODE, secretKey, parameterSpec);
-          return cipher.doFinal(input, GCM_IV_LEN, input.length - GCM_IV_LEN);
+          return encrypted;
         }
-      } else if (mode.equalsIgnoreCase("CBC") &&
-          (padding.equalsIgnoreCase("PKCS") || padding.equalsIgnoreCase("DEFAULT"))) {
-        Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
-        if (opmode == Cipher.ENCRYPT_MODE) {
-          byte[] salt = new byte[CBC_SALT_LEN];
-          secureRandom.nextBytes(salt);
-          final byte[] keyAndIv = getKeyAndIv(key, salt);
-          final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length);
-          final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, key.length + CBC_IV_LEN);
-          cipher.init(
-            Cipher.ENCRYPT_MODE,
-            new SecretKeySpec(keyValue, "AES"),
-            new IvParameterSpec(iv));
-          byte[] encrypted = cipher.doFinal(input, 0, input.length);
-          ByteBuffer byteBuffer = ByteBuffer.allocate(
-            SALTED_MAGIC.length + CBC_SALT_LEN + encrypted.length);
-          byteBuffer.put(SALTED_MAGIC);
-          byteBuffer.put(salt);
-          byteBuffer.put(encrypted);
-          return byteBuffer.array();
+      } else {
+        assert(opmode == Cipher.DECRYPT_MODE);
+        if (cipherMode.usesSpec) {
+          AlgorithmParameterSpec algSpec = getParamSpec(cipherMode, input);
+          cipher.init(opmode, secretKey, algSpec);
+          return cipher.doFinal(input, cipherMode.ivLength, input.length - cipherMode.ivLength);
         } else {
-          assert(opmode == Cipher.DECRYPT_MODE);
-          final byte[] shouldBeMagic = Arrays.copyOfRange(input, 0, SALTED_MAGIC.length);
-          if (!Arrays.equals(shouldBeMagic, SALTED_MAGIC)) {
-            throw QueryExecutionErrors.aesInvalidSalt(shouldBeMagic);
-          }
-          final byte[] salt = Arrays.copyOfRange(
-            input, SALTED_MAGIC.length, SALTED_MAGIC.length + CBC_SALT_LEN);
-          final byte[] keyAndIv = getKeyAndIv(key, salt);
-          final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length);
-          final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, key.length + CBC_IV_LEN);
-          cipher.init(
-            Cipher.DECRYPT_MODE,
-            new SecretKeySpec(keyValue, "AES"),
-            new IvParameterSpec(iv, 0, CBC_IV_LEN));
-          return cipher.doFinal(input, CBC_IV_LEN, input.length - CBC_IV_LEN);
+          cipher.init(opmode, secretKey);
+          return cipher.doFinal(input, 0, input.length);
         }
-      } else {
-        throw QueryExecutionErrors.aesModeUnsupportedError(mode, padding);
       }
     } catch (GeneralSecurityException e) {
       throw QueryExecutionErrors.aesCryptoError(e.getMessage());
     }
   }
-
-  // Derive the key and init vector in the same way as OpenSSL's EVP_BytesToKey
-  // since the version 1.1.0c which switched to SHA-256 as the hash.
-  private static byte[] getKeyAndIv(byte[] key, byte[] salt) throws NoSuchAlgorithmException {
-    final byte[] keyAndSalt = arrConcat(key, salt);
-    byte[] hash = new byte[0];
-    byte[] keyAndIv = new byte[0];
-    for (int i = 0; i < 3 && keyAndIv.length < key.length + CBC_IV_LEN; i++) {
-      final byte[] hashData = arrConcat(hash, keyAndSalt);
-      final MessageDigest md = MessageDigest.getInstance("SHA-256");
-      hash = md.digest(hashData);
-      keyAndIv = arrConcat(keyAndIv, hash);
-    }
-    return keyAndIv;
-  }
-
-  private static byte[] arrConcat(final byte[] arr1, final byte[] arr2) {
-    final byte[] res = new byte[arr1.length + arr2.length];
-    System.arraycopy(arr1, 0, res, 0, arr1.length);
-    System.arraycopy(arr2, 0, res, arr1.length, arr2.length);
-    return res;
-  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 00049cb113f..67328cde71a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -334,7 +334,7 @@ case class CurrentUser() extends LeafExpression with Unevaluable {
       > SELECT base64(_FUNC_('Spark SQL', '1234567890abcdef', 'ECB', 'PKCS'));
        3lmwu+Mw0H3fi5NDvcu9lg==
      > SELECT base64(_FUNC_('Apache Spark', '1234567890abcdef', 'CBC', 'DEFAULT'));
-       U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM=
+       2NYmDCjgXTbbxGA3/SnJEfFC/JQ7olk2VQWReIAAFKo=
   """,
   since = "3.3.0",
   group = "misc_funcs")
@@ -399,7 +399,7 @@ case class AesEncrypt(
        Spark SQL
      > SELECT _FUNC_(unbase64('3lmwu+Mw0H3fi5NDvcu9lg=='), '1234567890abcdef', 'ECB', 'PKCS');
        Spark SQL
-      > SELECT _FUNC_(unbase64('U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM='), '1234567890abcdef', 'CBC');
+      > SELECT _FUNC_(unbase64('2NYmDCjgXTbbxGA3/SnJEfFC/JQ7olk2VQWReIAAFKo='), '1234567890abcdef', 'CBC');
        Apache Spark
   """,
   since = "3.3.0",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index bc51727f8fb..26191a7dcba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -2656,15 +2656,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
         "detailMessage" -> detailMessage))
   }
 
-  def aesInvalidSalt(saltedMagic: Array[Byte]): RuntimeException = {
-    new SparkRuntimeException(
-      errorClass = "INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC",
-      messageParameters = Map(
-        "parameter" -> toSQLId("expr"),
-        "functionName" -> toSQLId("aes_decrypt"),
-        "saltedMagic" -> saltedMagic.map("%02X" format _).mkString("0x", "", "")))
-  }
-
   def hiveTableWithAnsiIntervalsError(tableName: String): SparkUnsupportedOperationException = {
     new SparkUnsupportedOperationException(
       errorClass = "_LEGACY_ERROR_TEMP_2276",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 5386150c8a7..3d1a5048748 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -344,6 +344,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
   }
 
   test("misc aes function") {
+    val key32 = "abcdefghijklmnop12345678ABCDEFGH"
+    val encryptedEcb = "9J3iZbIxnmaG+OIA9Amd+A=="
+    val encryptedGcm = "y5la3muiuxN2suj6VsYXB+0XUFjtrUD0/zv5eDafsA3U"
+    val encryptedCbc = "+MgyzJxhusYVGWCljk7fhhl6C6oUqWmtdqoaG93KvhY="
+    val df1 = Seq("Spark").toDF
+
+    // Successful decryption of fixed values
+    Seq(
+      (key32, encryptedEcb, "ECB"),
+      (key32, encryptedGcm, "GCM"),
+      (key32, encryptedCbc, "CBC")).foreach {
+      case (key, encryptedText, mode) =>
+        checkAnswer(
+          df1.selectExpr(
+            s"cast(aes_decrypt(unbase64('$encryptedText'), '$key', '$mode') as string)"),
+          Seq(Row("Spark")))
+        checkAnswer(
+          df1.selectExpr(
+            s"cast(aes_decrypt(unbase64('$encryptedText'), binary('$key'), '$mode') as string)"),
+          Seq(Row("Spark")))
+    }
+  }
+
+  test("misc aes ECB function") {
     val key16 = "abcdefghijklmnop"
     val key24 = "abcdefghijklmnop12345678"
     val key32 = "abcdefghijklmnop12345678ABCDEFGH"
@@ -358,15 +382,15 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
 
     // Successful encryption
     Seq(
-      (key16, encryptedText16, encryptedEmptyText16),
-      (key24, encryptedText24, encryptedEmptyText24),
-      (key32, encryptedText32, encryptedEmptyText32)).foreach {
-      case (key, encryptedText, encryptedEmptyText) =>
+      (key16, encryptedText16, encryptedEmptyText16, "ECB"),
+      (key24, encryptedText24, encryptedEmptyText24, "ECB"),
+      (key32, encryptedText32, encryptedEmptyText32, "ECB")).foreach {
+      case (key, encryptedText, encryptedEmptyText, mode) =>
         checkAnswer(
-          df1.selectExpr(s"base64(aes_encrypt(value, '$key', 'ECB'))"),
+          df1.selectExpr(s"base64(aes_encrypt(value, '$key', '$mode'))"),
           Seq(Row(encryptedText), Row(encryptedEmptyText)))
         checkAnswer(
-          df1.selectExpr(s"base64(aes_encrypt(binary(value), '$key', 'ECB'))"),
+          df1.selectExpr(s"base64(aes_encrypt(binary(value), '$key', '$mode'))"),
           Seq(Row(encryptedText), Row(encryptedEmptyText)))
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala
new file mode 100644
index 00000000000..2c17ec8dbe7
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.unsafe.types.UTF8String
+
+class ExpressionImplUtilsSuite extends SparkFunSuite {
+  case class TestCase(
+    plaintext: String,
+    key: String,
+    base64CiphertextExpected: String,
+    mode: String,
+    padding: String = "Default") {
+    val plaintextBytes = plaintext.getBytes("UTF-8")
+    val keyBytes = key.getBytes("UTF-8")
+    val utf8mode = UTF8String.fromString(mode)
+    val utf8Padding = UTF8String.fromString(padding)
+    val deterministic = mode.equalsIgnoreCase("ECB")
+  }
+
+  val testCases = Seq(
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop",
+      "4Hv0UKCx6nfUeAoPZo1z+w==",
+      "ECB"),
+    TestCase("Spark",
+      "abcdefghijklmnop12345678",
+      "NeTYNgA+PCQBN50DA//O2w==",
+      "ECB"),
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "9J3iZbIxnmaG+OIA9Amd+A==",
+      "ECB"),
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "+MgyzJxhusYVGWCljk7fhhl6C6oUqWmtdqoaG93KvhY=",
+      "CBC"),
+    TestCase(
+      "Apache Spark",
+      "1234567890abcdef",
+      "2NYmDCjgXTbbxGA3/SnJEfFC/JQ7olk2VQWReIAAFKo=",
+      "CBC"),
+    TestCase(
+      "Spark",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "y5la3muiuxN2suj6VsYXB+0XUFjtrUD0/zv5eDafsA3U",
+      "GCM"),
+    TestCase(
+      "This message is longer than a single AES block and should work fine.",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "agUfTbLT8KPsqbAmQn/YdpohvxqX5bBsfFjtxE5UwqvO6EWSUVy" +
+        "jeDA6r30XyS0ARebsBgXKSExaAVZ40NMgDLQa6/o9pieYwLT5YXI7flU=",
+      "ECB"),
+    TestCase(
+      "This message is longer than a single AES block and should work fine.",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "cxUKNdlZa/6hT6gdhp46OThPcdNONdBwJj/Ctl6z4gWVKfcA6DE" +
+        "lJg84LbkueIifjNOTloduKgidk9G9a4BDsn0NjlGLUeG8GH1moPWb/+knBC7oT/OOA06W6rJXudDo",
+      "CBC"),
+    TestCase(
+      "This message is longer than a single AES block and should work fine.",
+      "abcdefghijklmnop12345678ABCDEFGH",
+      "73B0tHM3F7bvmG7yIZB9vMKnzHyuCYjD9PzAI7NJ+kDBWtaFO22" +
+        "n2cKlkNcCzr45a4Uol+sNtQwQAV7iRhBdt6YmXoviemyXJWOZ89G279SgxabaomEIyN/HZwenxeN4",
+      "GCM")
+  )
+
+  test("AesDecrypt Only") {
+    val decoder = java.util.Base64.getDecoder
+    testCases.foreach { t =>
+      val expectedBytes = decoder.decode(t.base64CiphertextExpected)
+      val decryptedBytes =
+        ExpressionImplUtils.aesDecrypt(expectedBytes, t.keyBytes, t.utf8mode, t.utf8Padding)
+      val decryptedString = new String(decryptedBytes)
+      assert(decryptedString == t.plaintext)
+    }
+  }
+
+  test("AesEncrypt and AesDecrypt") {
+    val encoder = java.util.Base64.getEncoder
+    testCases.foreach { t =>
+      val ciphertextBytes =
+        ExpressionImplUtils.aesEncrypt(t.plaintextBytes, t.keyBytes, t.utf8mode, t.utf8Padding)
+      val ciphertextBase64 = encoder.encodeToString(ciphertextBytes)
+      val decryptedBytes =
+        ExpressionImplUtils.aesDecrypt(ciphertextBytes, t.keyBytes, t.utf8mode, t.utf8Padding)
+      val decryptedString = new String(decryptedBytes)
+      assert(decryptedString == t.plaintext)
+      if (t.deterministic) {
+        assert(t.base64CiphertextExpected == ciphertextBase64)
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index fe1cd7a4808..13dee57e8ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -140,25 +140,6 @@ class QueryExecutionErrorsSuite
     }
   }
 
-  test("INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC: AES decrypt failure - invalid salt") {
-    checkError(
-      exception = intercept[SparkRuntimeException] {
-        sql(
-          """
-            |SELECT aes_decrypt(
-            |  unbase64('INVALID_SALT_ERGxwEOTDpDD4bQvDtQaNe+gXGudCcUk='),
-            |  '0000111122223333',
-            |  'CBC', 'PKCS')
-            |""".stripMargin).collect()
-      },
-      errorClass = "INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC",
-      parameters = Map(
-        "parameter" -> "`expr`",
-        "functionName" -> "`aes_decrypt`",
-        "saltedMagic" -> "0x20D5402C80D200B4"),
-      sqlState = "22023")
-  }
-
   test("UNSUPPORTED_FEATURE: unsupported combinations of AES modes and 
padding") {
     val key16 = "abcdefghijklmnop"
     val key32 = "abcdefghijklmnop12345678ABCDEFGH"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
