This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dabd771c37b [SPARK-43038][SQL] Support the CBC mode by 
`aes_encrypt()`/`aes_decrypt()`
dabd771c37b is described below

commit dabd771c37be9cbd773b5223d8c78226ece84f8a
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Wed Apr 12 16:02:29 2023 +0300

    [SPARK-43038][SQL] Support the CBC mode by `aes_encrypt()`/`aes_decrypt()`
    
    ### What changes were proposed in this pull request?
    In the PR, I propose a new AES mode for the `aes_encrypt()`/`aes_decrypt()` 
functions - `CBC` ([Cipher Block 
Chaining](https://www.ibm.com/docs/en/linux-on-systems?topic=operation-cipher-block-chaining-cbc-mode))
 with the padding `PKCS7(5)`. The `aes_encrypt()` function returns a binary 
value which consists of the following fields:
    1. The salt magic prefix `Salted__` with the length of 8 bytes.
    2. A salt generated per every `aes_encrypt()` call using 
`java.security.SecureRandom`. Its length is 8 bytes.
    3. The encrypted input.
    
    The encrypt function derives the secret key and initialization vector (16 
bytes) from the salt and user's key using the same algorithm as OpenSSL's 
`EVP_BytesToKey()` (versions >= 1.1.0c).
    
    The `aes_decrypt()` function assumes that its input has the fields as 
shown above.
    
    For example:
    ```sql
    spark-sql> SELECT base64(aes_encrypt('Apache Spark', '0000111122223333', 
'CBC', 'PKCS'));
    U2FsdGVkX1/ERGxwEOTDpDD4bQvDtQaNe+gXGudCcUk=
    spark-sql> SELECT 
aes_decrypt(unbase64('U2FsdGVkX1/ERGxwEOTDpDD4bQvDtQaNe+gXGudCcUk='), 
'0000111122223333', 'CBC', 'PKCS');
    Apache Spark
    ```
    
    ### Why are the changes needed?
    To achieve feature parity with other systems/frameworks, and make the 
migration process from them to Spark SQL easier. For example, the `CBC` mode is 
supported by:
    - BigQuery: 
https://cloud.google.com/bigquery/docs/reference/standard-sql/aead-encryption-concepts#block_cipher_modes
    - Snowflake: 
https://docs.snowflake.com/en/sql-reference/functions/encrypt.html
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    By running new checks:
    ```
    $ build/sbt "sql/testOnly *QueryExecutionErrorsSuite"
    $ build/sbt "sql/test:testOnly 
org.apache.spark.sql.expressions.ExpressionInfoSuite"
    $ build/sbt "test:testOnly org.apache.spark.sql.MiscFunctionsSuite"
    $ build/sbt "core/testOnly *SparkThrowableSuite"
    ```
    and checked compatibility with LibreSSL/OpenSSL:
    ```
    $ openssl version
    LibreSSL 3.3.6
    $ echo -n 'Apache Spark' | openssl enc -e -aes-128-cbc -pass 
pass:0000111122223333 -a
    U2FsdGVkX1+5GyAmmG7wDWWDBAuUuxjMy++cMFytpls=
    ```
    ```sql
    spark-sql (default)> SELECT 
aes_decrypt(unbase64('U2FsdGVkX1+5GyAmmG7wDWWDBAuUuxjMy++cMFytpls='), 
'0000111122223333', 'CBC');
    Apache Spark
    ```
    Decrypting Spark's output with OpenSSL:
    ```sql
    spark-sql (default)> SELECT base64(aes_encrypt('Apache Spark', 
'abcdefghijklmnop12345678ABCDEFGH', 'CBC', 'PKCS'));
    U2FsdGVkX1+maU2vmxrulgxXuQSyZ3ODnlHKqnt2fDA=
    ```
    ```
    $ echo 'U2FsdGVkX1+maU2vmxrulgxXuQSyZ3ODnlHKqnt2fDA=' | openssl aes-256-cbc 
-a -d -pass pass:abcdefghijklmnop12345678ABCDEFGH
    Apache Spark
    ```
    
    Closes #40704 from MaxGekk/aes-cbc.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 core/src/main/resources/error/error-classes.json   |  5 ++
 .../catalyst/expressions/ExpressionImplUtils.java  | 72 ++++++++++++++++++++++
 .../spark/sql/catalyst/expressions/misc.scala      | 16 +++--
 .../spark/sql/errors/QueryExecutionErrors.scala    |  9 +++
 .../org/apache/spark/sql/MiscFunctionsSuite.scala  | 33 +++++-----
 .../sql/errors/QueryExecutionErrorsSuite.scala     | 31 ++++++++--
 6 files changed, 141 insertions(+), 25 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json 
b/core/src/main/resources/error/error-classes.json
index ae73071a120..1edf625fdc3 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -978,6 +978,11 @@
           "expects a binary value with 16, 24 or 32 bytes, but got 
<actualLength> bytes."
         ]
       },
+      "AES_SALTED_MAGIC" : {
+        "message" : [
+          "Initial bytes from input <saltedMagic> do not match 'Salted__' 
(0x53616C7465645F5F)."
+        ]
+      },
       "PATTERN" : {
         "message" : [
           "<value>."
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
index a6e482db57b..680ad11ad73 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
@@ -22,10 +22,16 @@ import org.apache.spark.unsafe.types.UTF8String;
 
 import javax.crypto.Cipher;
 import javax.crypto.spec.GCMParameterSpec;
+import javax.crypto.spec.IvParameterSpec;
 import javax.crypto.spec.SecretKeySpec;
 import java.nio.ByteBuffer;
 import java.security.GeneralSecurityException;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
 import java.security.SecureRandom;
+import java.util.Arrays;
+
+import static java.nio.charset.StandardCharsets.US_ASCII;
 
 /**
  * An utility class for constructing expressions.
@@ -35,6 +41,13 @@ public class ExpressionImplUtils {
   private static final int GCM_IV_LEN = 12;
   private static final int GCM_TAG_LEN = 128;
 
+  private static final int CBC_IV_LEN = 16;
+  private static final int CBC_SALT_LEN = 8;
+  /** OpenSSL's magic initial bytes. */
+  private static final String SALTED_STR = "Salted__";
+  private static final byte[] SALTED_MAGIC = SALTED_STR.getBytes(US_ASCII);
+
+
   /**
    * Function to check if a given number string is a valid Luhn number
    * @param numberString
@@ -115,6 +128,43 @@ public class ExpressionImplUtils {
           cipher.init(Cipher.DECRYPT_MODE, secretKey, parameterSpec);
           return cipher.doFinal(input, GCM_IV_LEN, input.length - GCM_IV_LEN);
         }
+      } else if (mode.equalsIgnoreCase("CBC") &&
+          (padding.equalsIgnoreCase("PKCS") || 
padding.equalsIgnoreCase("DEFAULT"))) {
+        Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
+        if (opmode == Cipher.ENCRYPT_MODE) {
+          byte[] salt = new byte[CBC_SALT_LEN];
+          secureRandom.nextBytes(salt);
+          final byte[] keyAndIv = getKeyAndIv(key, salt);
+          final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length);
+          final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, 
key.length + CBC_IV_LEN);
+          cipher.init(
+            Cipher.ENCRYPT_MODE,
+            new SecretKeySpec(keyValue, "AES"),
+            new IvParameterSpec(iv));
+          byte[] encrypted = cipher.doFinal(input, 0, input.length);
+          ByteBuffer byteBuffer = ByteBuffer.allocate(
+            SALTED_MAGIC.length + CBC_SALT_LEN + encrypted.length);
+          byteBuffer.put(SALTED_MAGIC);
+          byteBuffer.put(salt);
+          byteBuffer.put(encrypted);
+          return byteBuffer.array();
+        } else {
+          assert(opmode == Cipher.DECRYPT_MODE);
+          final byte[] shouldBeMagic = Arrays.copyOfRange(input, 0, 
SALTED_MAGIC.length);
+          if (!Arrays.equals(shouldBeMagic, SALTED_MAGIC)) {
+            throw QueryExecutionErrors.aesInvalidSalt(shouldBeMagic);
+          }
+          final byte[] salt = Arrays.copyOfRange(
+            input, SALTED_MAGIC.length, SALTED_MAGIC.length + CBC_SALT_LEN);
+          final byte[] keyAndIv = getKeyAndIv(key, salt);
+          final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length);
+          final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, 
key.length + CBC_IV_LEN);
+          cipher.init(
+            Cipher.DECRYPT_MODE,
+            new SecretKeySpec(keyValue, "AES"),
+            new IvParameterSpec(iv, 0, CBC_IV_LEN));
+          return cipher.doFinal(input, CBC_IV_LEN, input.length - CBC_IV_LEN);
+        }
       } else {
         throw QueryExecutionErrors.aesModeUnsupportedError(mode, padding);
       }
@@ -122,4 +172,26 @@ public class ExpressionImplUtils {
       throw QueryExecutionErrors.aesCryptoError(e.getMessage());
     }
   }
+
+  // Derive the key and init vector in the same way as OpenSSL's EVP_BytesToKey
+  // since the version 1.1.0c which switched to SHA-256 as the hash.
+  private static byte[] getKeyAndIv(byte[] key, byte[] salt) throws 
NoSuchAlgorithmException {
+    final byte[] keyAndSalt = arrConcat(key, salt);
+    byte[] hash = new byte[0];
+    byte[] keyAndIv = new byte[0];
+    for (int i = 0; i < 3 && keyAndIv.length < key.length + CBC_IV_LEN; i++) {
+      final byte[] hashData = arrConcat(hash, keyAndSalt);
+      final MessageDigest md = MessageDigest.getInstance("SHA-256");
+      hash = md.digest(hashData);
+      keyAndIv = arrConcat(keyAndIv, hash);
+    }
+    return keyAndIv;
+  }
+
+  private static byte[] arrConcat(final byte[] arr1, final byte[] arr2) {
+    final byte[] res = new byte[arr1.length + arr2.length];
+    System.arraycopy(arr1, 0, res, 0, arr1.length);
+    System.arraycopy(arr2, 0, res, arr1.length, arr2.length);
+    return res;
+  }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 300fab0386c..00049cb113f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -313,7 +313,7 @@ case class CurrentUser() extends LeafExpression with 
Unevaluable {
 @ExpressionDescription(
   usage = """
     _FUNC_(expr, key[, mode[, padding]]) - Returns an encrypted value of 
`expr` using AES in given `mode` with the specified `padding`.
-      Key lengths of 16, 24 and 32 bits are supported. Supported combinations 
of (`mode`, `padding`) are ('ECB', 'PKCS') and ('GCM', 'NONE').
+      Key lengths of 16, 24 and 32 bits are supported. Supported combinations 
of (`mode`, `padding`) are ('ECB', 'PKCS'), ('GCM', 'NONE') and ('CBC', 'PKCS').
       The default mode is GCM.
   """,
   arguments = """
@@ -321,9 +321,9 @@ case class CurrentUser() extends LeafExpression with 
Unevaluable {
       * expr - The binary value to encrypt.
       * key - The passphrase to use to encrypt the data.
       * mode - Specifies which block cipher mode should be used to encrypt 
messages.
-               Valid modes: ECB, GCM.
+               Valid modes: ECB, GCM, CBC.
       * padding - Specifies how to pad messages whose length is not a multiple 
of the block size.
-                  Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means 
PKCS for ECB and NONE for GCM.
+                  Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means 
PKCS for ECB, NONE for GCM and PKCS for CBC.
   """,
   examples = """
     Examples:
@@ -333,6 +333,8 @@ case class CurrentUser() extends LeafExpression with 
Unevaluable {
        
6E7CA17BBB468D3084B5744BCA729FB7B2B7BCB8E4472847D02670489D95FA97DBBA7D3210
       > SELECT base64(_FUNC_('Spark SQL', '1234567890abcdef', 'ECB', 'PKCS'));
        3lmwu+Mw0H3fi5NDvcu9lg==
+      > SELECT base64(_FUNC_('Apache Spark', '1234567890abcdef', 'CBC', 
'DEFAULT'));
+       U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM=
   """,
   since = "3.3.0",
   group = "misc_funcs")
@@ -377,7 +379,7 @@ case class AesEncrypt(
 @ExpressionDescription(
   usage = """
     _FUNC_(expr, key[, mode[, padding]]) - Returns a decrypted value of `expr` 
using AES in `mode` with `padding`.
-      Key lengths of 16, 24 and 32 bits are supported. Supported combinations 
of (`mode`, `padding`) are ('ECB', 'PKCS') and ('GCM', 'NONE').
+      Key lengths of 16, 24 and 32 bits are supported. Supported combinations 
of (`mode`, `padding`) are ('ECB', 'PKCS'), ('GCM', 'NONE') and ('CBC', 'PKCS').
       The default mode is GCM.
   """,
   arguments = """
@@ -385,9 +387,9 @@ case class AesEncrypt(
       * expr - The binary value to decrypt.
       * key - The passphrase to use to decrypt the data.
       * mode - Specifies which block cipher mode should be used to decrypt 
messages.
-               Valid modes: ECB, GCM.
+               Valid modes: ECB, GCM, CBC.
       * padding - Specifies how to pad messages whose length is not a multiple 
of the block size.
-                  Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means 
PKCS for ECB and NONE for GCM.
+                  Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means 
PKCS for ECB, NONE for GCM and PKCS for CBC.
   """,
   examples = """
     Examples:
@@ -397,6 +399,8 @@ case class AesEncrypt(
        Spark SQL
       > SELECT _FUNC_(unbase64('3lmwu+Mw0H3fi5NDvcu9lg=='), 
'1234567890abcdef', 'ECB', 'PKCS');
        Spark SQL
+      > SELECT 
_FUNC_(unbase64('U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM='), 
'1234567890abcdef', 'CBC');
+       Apache Spark
   """,
   since = "3.3.0",
   group = "misc_funcs")
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index d07dcec3693..11fe84990c1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -2651,6 +2651,15 @@ private[sql] object QueryExecutionErrors extends 
QueryErrorsBase {
         "detailMessage" -> detailMessage))
   }
 
+  def aesInvalidSalt(saltedMagic: Array[Byte]): RuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC",
+      messageParameters = Map(
+        "parameter" -> toSQLId("expr"),
+        "functionName" -> toSQLId("aes_decrypt"),
+        "saltedMagic" -> saltedMagic.map("%02X" format _).mkString("0x", "", 
"")))
+  }
+
   def hiveTableWithAnsiIntervalsError(tableName: String): 
SparkUnsupportedOperationException = {
     new SparkUnsupportedOperationException(
       errorClass = "_LEGACY_ERROR_TEMP_2276",
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala
index 45ae3e54977..d498982fb2d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala
@@ -62,21 +62,26 @@ class MiscFunctionsSuite extends QueryTest with 
SharedSparkSession {
     }
   }
 
-  test("SPARK-37591: AES functions - GCM mode") {
+  test("SPARK-37591, SPARK-43038: AES functions - GCM/CBC mode") {
     Seq(
-      ("abcdefghijklmnop", ""),
-      ("abcdefghijklmnop", "abcdefghijklmnop"),
-      ("abcdefghijklmnop12345678", "Spark"),
-      ("abcdefghijklmnop12345678ABCDEFGH", "GCM mode")
-    ).foreach { case (key, input) =>
-      val df = Seq((key, input)).toDF("key", "input")
-      val encrypted = df.selectExpr("aes_encrypt(input, key, 'GCM', 'NONE') AS 
enc", "input", "key")
-      assert(encrypted.schema("enc").dataType === BinaryType)
-      assert(encrypted.filter($"enc" === $"input").isEmpty)
-      val result = encrypted.selectExpr(
-        "CAST(aes_decrypt(enc, key, 'GCM', 'NONE') AS STRING) AS res", "input")
-      assert(!result.filter($"res" === $"input").isEmpty &&
-        result.filter($"res" =!= $"input").isEmpty)
+      "GCM" -> "NONE",
+      "CBC" -> "PKCS").foreach { case (mode, padding) =>
+      Seq(
+        ("abcdefghijklmnop", ""),
+        ("abcdefghijklmnop", "abcdefghijklmnop"),
+        ("abcdefghijklmnop12345678", "Spark"),
+        ("abcdefghijklmnop12345678ABCDEFGH", "GCM mode")
+      ).foreach { case (key, input) =>
+        val df = Seq((key, input)).toDF("key", "input")
+        val encrypted = df.selectExpr(
+          s"aes_encrypt(input, key, '$mode', '$padding') AS enc", "input", 
"key")
+        assert(encrypted.schema("enc").dataType === BinaryType)
+        assert(encrypted.filter($"enc" === $"input").isEmpty)
+        val result = encrypted.selectExpr(
+          s"CAST(aes_decrypt(enc, key, '$mode', '$padding') AS STRING) AS 
res", "input")
+        assert(!result.filter($"res" === $"input").isEmpty &&
+          result.filter($"res" =!= $"input").isEmpty)
+      }
     }
   }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index f2d4ade167a..de59eec505a 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -140,6 +140,25 @@ class QueryExecutionErrorsSuite
     }
   }
 
+  test("INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC: AES decrypt failure - 
invalid salt") {
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        sql(
+          """
+            |SELECT aes_decrypt(
+            |  unbase64('INVALID_SALT_ERGxwEOTDpDD4bQvDtQaNe+gXGudCcUk='),
+            |  '0000111122223333',
+            |  'CBC', 'PKCS')
+            |""".stripMargin).collect()
+      },
+      errorClass = "INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC",
+      parameters = Map(
+        "parameter" -> "`expr`",
+        "functionName" -> "`aes_decrypt`",
+        "saltedMagic" -> "0x20D5402C80D200B4"),
+      sqlState = "22023")
+  }
+
   test("UNSUPPORTED_FEATURE: unsupported combinations of AES modes and 
padding") {
     val key16 = "abcdefghijklmnop"
     val key32 = "abcdefghijklmnop12345678ABCDEFGH"
@@ -157,18 +176,20 @@ class QueryExecutionErrorsSuite
     }
 
     // Unsupported AES mode and padding in encrypt
-    checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 
'CBC')"),
-      "CBC", "DEFAULT")
+    checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'CBC', 
'None')"),
+      "CBC", "None")
     checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'ECB', 
'NoPadding')"),
       "ECB", "NoPadding")
 
     // Unsupported AES mode and padding in decrypt
     checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value16, '$key16', 
'GSM')"),
-    "GSM", "DEFAULT")
+      "GSM", "DEFAULT")
     checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value16, '$key16', 
'GCM', 'PKCS')"),
-    "GCM", "PKCS")
+      "GCM", "PKCS")
     checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value32, '$key32', 
'ECB', 'None')"),
-    "ECB", "None")
+      "ECB", "None")
+    checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value32, '$key32', 
'CBC', 'NoPadding')"),
+      "CBC", "NoPadding")
   }
 
   test("UNSUPPORTED_FEATURE: unsupported types (map and struct) in lit()") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to