This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new dabd771c37b [SPARK-43038][SQL] Support the CBC mode by `aes_encrypt()`/`aes_decrypt()` dabd771c37b is described below commit dabd771c37be9cbd773b5223d8c78226ece84f8a Author: Max Gekk <max.g...@gmail.com> AuthorDate: Wed Apr 12 16:02:29 2023 +0300 [SPARK-43038][SQL] Support the CBC mode by `aes_encrypt()`/`aes_decrypt()` ### What changes were proposed in this pull request? In the PR, I propose a new AES mode for the `aes_encrypt()`/`aes_decrypt()` functions - `CBC` ([Cipher Block Chaining](https://www.ibm.com/docs/en/linux-on-systems?topic=operation-cipher-block-chaining-cbc-mode)) with the padding `PKCS7(5)`. The `aes_encrypt()` function returns a binary value which consists of the following fields: 1. The salt magic prefix `Salted__` with a length of 8 bytes. 2. A salt generated on every `aes_encrypt()` call using `java.security.SecureRandom`. Its length is 8 bytes. 3. The encrypted input. The encrypt function derives the secret key and initialization vector (16 bytes) from the salt and the user's key using the same algorithm as OpenSSL's `EVP_BytesToKey()` (versions >= 1.1.0c). The `aes_decrypt()` function assumes that its input has the fields shown above. For example: ```sql spark-sql> SELECT base64(aes_encrypt('Apache Spark', '0000111122223333', 'CBC', 'PKCS')); U2FsdGVkX1/ERGxwEOTDpDD4bQvDtQaNe+gXGudCcUk= spark-sql> SELECT aes_decrypt(unbase64('U2FsdGVkX1/ERGxwEOTDpDD4bQvDtQaNe+gXGudCcUk='), '0000111122223333', 'CBC', 'PKCS'); Apache Spark ``` ### Why are the changes needed? To achieve feature parity with other systems/frameworks, and make the migration process from them to Spark SQL easier. For example, the `CBC` mode is supported by: - BigQuery: https://cloud.google.com/bigquery/docs/reference/standard-sql/aead-encryption-concepts#block_cipher_modes - Snowflake: https://docs.snowflake.com/en/sql-reference/functions/encrypt.html ### Does this PR introduce _any_ user-facing change? No. 
### How was this patch tested? By running new checks: ``` $ build/sbt "sql/testOnly *QueryExecutionErrorsSuite" $ build/sbt "sql/test:testOnly org.apache.spark.sql.expressions.ExpressionInfoSuite" $ build/sbt "test:testOnly org.apache.spark.sql.MiscFunctionsSuite" $ build/sbt "core/testOnly *SparkThrowableSuite" ``` and checked compatibility with LibreSSL/OpenSSL: ``` $ openssl version LibreSSL 3.3.6 $ echo -n 'Apache Spark' | openssl enc -e -aes-128-cbc -pass pass:0000111122223333 -a U2FsdGVkX1+5GyAmmG7wDWWDBAuUuxjMy++cMFytpls= ``` ```sql spark-sql (default)> SELECT aes_decrypt(unbase64('U2FsdGVkX1+5GyAmmG7wDWWDBAuUuxjMy++cMFytpls='), '0000111122223333', 'CBC'); Apache Spark ``` decrypt Spark's output by OpenSSL: ```sql spark-sql (default)> SELECT base64(aes_encrypt('Apache Spark', 'abcdefghijklmnop12345678ABCDEFGH', 'CBC', 'PKCS')); U2FsdGVkX1+maU2vmxrulgxXuQSyZ3ODnlHKqnt2fDA= ``` ``` $ echo 'U2FsdGVkX1+maU2vmxrulgxXuQSyZ3ODnlHKqnt2fDA=' | openssl aes-256-cbc -a -d -pass pass:abcdefghijklmnop12345678ABCDEFGH Apache Spark ``` Closes #40704 from MaxGekk/aes-cbc. Authored-by: Max Gekk <max.g...@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- core/src/main/resources/error/error-classes.json | 5 ++ .../catalyst/expressions/ExpressionImplUtils.java | 72 ++++++++++++++++++++++ .../spark/sql/catalyst/expressions/misc.scala | 16 +++-- .../spark/sql/errors/QueryExecutionErrors.scala | 9 +++ .../org/apache/spark/sql/MiscFunctionsSuite.scala | 33 +++++----- .../sql/errors/QueryExecutionErrorsSuite.scala | 31 ++++++++-- 6 files changed, 141 insertions(+), 25 deletions(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index ae73071a120..1edf625fdc3 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -978,6 +978,11 @@ "expects a binary value with 16, 24 or 32 bytes, but got <actualLength> bytes." 
] }, + "AES_SALTED_MAGIC" : { + "message" : [ + "Initial bytes from input <saltedMagic> do not match 'Salted__' (0x53616C7465645F5F)." + ] + }, "PATTERN" : { "message" : [ "<value>." diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java index a6e482db57b..680ad11ad73 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java @@ -22,10 +22,16 @@ import org.apache.spark.unsafe.types.UTF8String; import javax.crypto.Cipher; import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.IvParameterSpec; import javax.crypto.spec.SecretKeySpec; import java.nio.ByteBuffer; import java.security.GeneralSecurityException; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; +import java.util.Arrays; + +import static java.nio.charset.StandardCharsets.US_ASCII; /** * An utility class for constructing expressions. @@ -35,6 +41,13 @@ public class ExpressionImplUtils { private static final int GCM_IV_LEN = 12; private static final int GCM_TAG_LEN = 128; + private static final int CBC_IV_LEN = 16; + private static final int CBC_SALT_LEN = 8; + /** OpenSSL's magic initial bytes. 
*/ + private static final String SALTED_STR = "Salted__"; + private static final byte[] SALTED_MAGIC = SALTED_STR.getBytes(US_ASCII); + + /** * Function to check if a given number string is a valid Luhn number * @param numberString @@ -115,6 +128,43 @@ public class ExpressionImplUtils { cipher.init(Cipher.DECRYPT_MODE, secretKey, parameterSpec); return cipher.doFinal(input, GCM_IV_LEN, input.length - GCM_IV_LEN); } + } else if (mode.equalsIgnoreCase("CBC") && + (padding.equalsIgnoreCase("PKCS") || padding.equalsIgnoreCase("DEFAULT"))) { + Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding"); + if (opmode == Cipher.ENCRYPT_MODE) { + byte[] salt = new byte[CBC_SALT_LEN]; + secureRandom.nextBytes(salt); + final byte[] keyAndIv = getKeyAndIv(key, salt); + final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length); + final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, key.length + CBC_IV_LEN); + cipher.init( + Cipher.ENCRYPT_MODE, + new SecretKeySpec(keyValue, "AES"), + new IvParameterSpec(iv)); + byte[] encrypted = cipher.doFinal(input, 0, input.length); + ByteBuffer byteBuffer = ByteBuffer.allocate( + SALTED_MAGIC.length + CBC_SALT_LEN + encrypted.length); + byteBuffer.put(SALTED_MAGIC); + byteBuffer.put(salt); + byteBuffer.put(encrypted); + return byteBuffer.array(); + } else { + assert(opmode == Cipher.DECRYPT_MODE); + final byte[] shouldBeMagic = Arrays.copyOfRange(input, 0, SALTED_MAGIC.length); + if (!Arrays.equals(shouldBeMagic, SALTED_MAGIC)) { + throw QueryExecutionErrors.aesInvalidSalt(shouldBeMagic); + } + final byte[] salt = Arrays.copyOfRange( + input, SALTED_MAGIC.length, SALTED_MAGIC.length + CBC_SALT_LEN); + final byte[] keyAndIv = getKeyAndIv(key, salt); + final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length); + final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, key.length + CBC_IV_LEN); + cipher.init( + Cipher.DECRYPT_MODE, + new SecretKeySpec(keyValue, "AES"), + new IvParameterSpec(iv, 0, CBC_IV_LEN)); + 
return cipher.doFinal(input, CBC_IV_LEN, input.length - CBC_IV_LEN); + } } else { throw QueryExecutionErrors.aesModeUnsupportedError(mode, padding); } @@ -122,4 +172,26 @@ public class ExpressionImplUtils { throw QueryExecutionErrors.aesCryptoError(e.getMessage()); } } + + // Derive the key and init vector in the same way as OpenSSL's EVP_BytesToKey + // since the version 1.1.0c which switched to SHA-256 as the hash. + private static byte[] getKeyAndIv(byte[] key, byte[] salt) throws NoSuchAlgorithmException { + final byte[] keyAndSalt = arrConcat(key, salt); + byte[] hash = new byte[0]; + byte[] keyAndIv = new byte[0]; + for (int i = 0; i < 3 && keyAndIv.length < key.length + CBC_IV_LEN; i++) { + final byte[] hashData = arrConcat(hash, keyAndSalt); + final MessageDigest md = MessageDigest.getInstance("SHA-256"); + hash = md.digest(hashData); + keyAndIv = arrConcat(keyAndIv, hash); + } + return keyAndIv; + } + + private static byte[] arrConcat(final byte[] arr1, final byte[] arr2) { + final byte[] res = new byte[arr1.length + arr2.length]; + System.arraycopy(arr1, 0, res, 0, arr1.length); + System.arraycopy(arr2, 0, res, arr1.length, arr2.length); + return res; + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 300fab0386c..00049cb113f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -313,7 +313,7 @@ case class CurrentUser() extends LeafExpression with Unevaluable { @ExpressionDescription( usage = """ _FUNC_(expr, key[, mode[, padding]]) - Returns an encrypted value of `expr` using AES in given `mode` with the specified `padding`. - Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS') and ('GCM', 'NONE'). 
+ Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS'), ('GCM', 'NONE') and ('CBC', 'PKCS'). The default mode is GCM. """, arguments = """ @@ -321,9 +321,9 @@ case class CurrentUser() extends LeafExpression with Unevaluable { * expr - The binary value to encrypt. * key - The passphrase to use to encrypt the data. * mode - Specifies which block cipher mode should be used to encrypt messages. - Valid modes: ECB, GCM. + Valid modes: ECB, GCM, CBC. * padding - Specifies how to pad messages whose length is not a multiple of the block size. - Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB and NONE for GCM. + Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB, NONE for GCM and PKCS for CBC. """, examples = """ Examples: @@ -333,6 +333,8 @@ case class CurrentUser() extends LeafExpression with Unevaluable { 6E7CA17BBB468D3084B5744BCA729FB7B2B7BCB8E4472847D02670489D95FA97DBBA7D3210 > SELECT base64(_FUNC_('Spark SQL', '1234567890abcdef', 'ECB', 'PKCS')); 3lmwu+Mw0H3fi5NDvcu9lg== + > SELECT base64(_FUNC_('Apache Spark', '1234567890abcdef', 'CBC', 'DEFAULT')); + U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM= """, since = "3.3.0", group = "misc_funcs") @@ -377,7 +379,7 @@ case class AesEncrypt( @ExpressionDescription( usage = """ _FUNC_(expr, key[, mode[, padding]]) - Returns a decrypted value of `expr` using AES in `mode` with `padding`. - Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS') and ('GCM', 'NONE'). + Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS'), ('GCM', 'NONE') and ('CBC', 'PKCS'). The default mode is GCM. """, arguments = """ @@ -385,9 +387,9 @@ case class AesEncrypt( * expr - The binary value to decrypt. * key - The passphrase to use to decrypt the data. 
* mode - Specifies which block cipher mode should be used to decrypt messages. - Valid modes: ECB, GCM. + Valid modes: ECB, GCM, CBC. * padding - Specifies how to pad messages whose length is not a multiple of the block size. - Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB and NONE for GCM. + Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB, NONE for GCM and PKCS for CBC. """, examples = """ Examples: @@ -397,6 +399,8 @@ case class AesEncrypt( Spark SQL > SELECT _FUNC_(unbase64('3lmwu+Mw0H3fi5NDvcu9lg=='), '1234567890abcdef', 'ECB', 'PKCS'); Spark SQL + > SELECT _FUNC_(unbase64('U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM='), '1234567890abcdef', 'CBC'); + Apache Spark """, since = "3.3.0", group = "misc_funcs") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index d07dcec3693..11fe84990c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2651,6 +2651,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { "detailMessage" -> detailMessage)) } + def aesInvalidSalt(saltedMagic: Array[Byte]): RuntimeException = { + new SparkRuntimeException( + errorClass = "INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC", + messageParameters = Map( + "parameter" -> toSQLId("expr"), + "functionName" -> toSQLId("aes_decrypt"), + "saltedMagic" -> saltedMagic.map("%02X" format _).mkString("0x", "", ""))) + } + def hiveTableWithAnsiIntervalsError(tableName: String): SparkUnsupportedOperationException = { new SparkUnsupportedOperationException( errorClass = "_LEGACY_ERROR_TEMP_2276", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala index 
45ae3e54977..d498982fb2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala @@ -62,21 +62,26 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession { } } - test("SPARK-37591: AES functions - GCM mode") { + test("SPARK-37591, SPARK-43038: AES functions - GCM/CBC mode") { Seq( - ("abcdefghijklmnop", ""), - ("abcdefghijklmnop", "abcdefghijklmnop"), - ("abcdefghijklmnop12345678", "Spark"), - ("abcdefghijklmnop12345678ABCDEFGH", "GCM mode") - ).foreach { case (key, input) => - val df = Seq((key, input)).toDF("key", "input") - val encrypted = df.selectExpr("aes_encrypt(input, key, 'GCM', 'NONE') AS enc", "input", "key") - assert(encrypted.schema("enc").dataType === BinaryType) - assert(encrypted.filter($"enc" === $"input").isEmpty) - val result = encrypted.selectExpr( - "CAST(aes_decrypt(enc, key, 'GCM', 'NONE') AS STRING) AS res", "input") - assert(!result.filter($"res" === $"input").isEmpty && - result.filter($"res" =!= $"input").isEmpty) + "GCM" -> "NONE", + "CBC" -> "PKCS").foreach { case (mode, padding) => + Seq( + ("abcdefghijklmnop", ""), + ("abcdefghijklmnop", "abcdefghijklmnop"), + ("abcdefghijklmnop12345678", "Spark"), + ("abcdefghijklmnop12345678ABCDEFGH", "GCM mode") + ).foreach { case (key, input) => + val df = Seq((key, input)).toDF("key", "input") + val encrypted = df.selectExpr( + s"aes_encrypt(input, key, '$mode', '$padding') AS enc", "input", "key") + assert(encrypted.schema("enc").dataType === BinaryType) + assert(encrypted.filter($"enc" === $"input").isEmpty) + val result = encrypted.selectExpr( + s"CAST(aes_decrypt(enc, key, '$mode', '$padding') AS STRING) AS res", "input") + assert(!result.filter($"res" === $"input").isEmpty && + result.filter($"res" =!= $"input").isEmpty) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index f2d4ade167a..de59eec505a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -140,6 +140,25 @@ class QueryExecutionErrorsSuite } } + test("INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC: AES decrypt failure - invalid salt") { + checkError( + exception = intercept[SparkRuntimeException] { + sql( + """ + |SELECT aes_decrypt( + | unbase64('INVALID_SALT_ERGxwEOTDpDD4bQvDtQaNe+gXGudCcUk='), + | '0000111122223333', + | 'CBC', 'PKCS') + |""".stripMargin).collect() + }, + errorClass = "INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC", + parameters = Map( + "parameter" -> "`expr`", + "functionName" -> "`aes_decrypt`", + "saltedMagic" -> "0x20D5402C80D200B4"), + sqlState = "22023") + } + test("UNSUPPORTED_FEATURE: unsupported combinations of AES modes and padding") { val key16 = "abcdefghijklmnop" val key32 = "abcdefghijklmnop12345678ABCDEFGH" @@ -157,18 +176,20 @@ class QueryExecutionErrorsSuite } // Unsupported AES mode and padding in encrypt - checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'CBC')"), - "CBC", "DEFAULT") + checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'CBC', 'None')"), + "CBC", "None") checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'ECB', 'NoPadding')"), "ECB", "NoPadding") // Unsupported AES mode and padding in decrypt checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value16, '$key16', 'GSM')"), - "GSM", "DEFAULT") + "GSM", "DEFAULT") checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value16, '$key16', 'GCM', 'PKCS')"), - "GCM", "PKCS") + "GCM", "PKCS") checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value32, '$key32', 'ECB', 'None')"), - "ECB", "None") + "ECB", "None") + checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value32, '$key32', 'CBC', 'NoPadding')"), + "CBC", 
"NoPadding") } test("UNSUPPORTED_FEATURE: unsupported types (map and struct) in lit()") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org