This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 4dd6923787 fix: Fix SparkSha2 to be compliant with Spark response and add support for Int32 (#16350)
4dd6923787 is described below

commit 4dd6923787084548c9ecc6d90c630c2c28ee9259
Author: Rishab Joshi <8187657+rish...@users.noreply.github.com>
AuthorDate: Fri Jun 13 10:45:33 2025 -0700

    fix: Fix SparkSha2 to be compliant with Spark response and add support for Int32 (#16350)
    
    * fix: Fix SparkSha2 to be compliant with Spark response and add support for Int32.
    
    * Fixed test cases.
    
    * Addressed comments.
    
    * Fixed missed test case.
    
    * Minor cosmetic changes.
---
 datafusion/spark/src/function/hash/sha2.rs         | 61 ++++++++++++----------
 datafusion/spark/src/function/math/hex.rs          | 31 ++++++++---
 .../sqllogictest/test_files/spark/hash/sha2.slt    | 48 ++++++++++-------
 3 files changed, 85 insertions(+), 55 deletions(-)

diff --git a/datafusion/spark/src/function/hash/sha2.rs b/datafusion/spark/src/function/hash/sha2.rs
index b4b29ef334..a8bb8c21a2 100644
--- a/datafusion/spark/src/function/hash/sha2.rs
+++ b/datafusion/spark/src/function/hash/sha2.rs
@@ -20,9 +20,9 @@ extern crate datafusion_functions;
 use crate::function::error_utils::{
     invalid_arg_count_exec_err, unsupported_data_type_exec_err,
 };
-use crate::function::math::hex::spark_hex;
+use crate::function::math::hex::spark_sha2_hex;
 use arrow::array::{ArrayRef, AsArray, StringArray};
-use arrow::datatypes::{DataType, UInt32Type};
+use arrow::datatypes::{DataType, Int32Type};
 use datafusion_common::{exec_err, internal_datafusion_err, Result, 
ScalarValue};
 use datafusion_expr::Signature;
 use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, 
Volatility};
@@ -121,7 +121,7 @@ impl ScalarUDFImpl for SparkSha2 {
             )),
         }?;
         let bit_length_type = if arg_types[1].is_numeric() {
-            Ok(DataType::UInt32)
+            Ok(DataType::Int32)
         } else if arg_types[1].is_null() {
             Ok(DataType::Null)
         } else {
@@ -138,39 +138,24 @@ impl ScalarUDFImpl for SparkSha2 {
 
 pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
     match args {
-        [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), 
ColumnarValue::Scalar(ScalarValue::UInt32(Some(bit_length_arg)))] => {
-            match bit_length_arg {
-                0 | 256 => 
sha256(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]),
-                224 => 
sha224(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]),
-                384 => 
sha384(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]),
-                512 => 
sha512(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]),
-                _ => exec_err!(
-                    "sha2 function only supports 224, 256, 384, and 512 bit 
lengths."
-                ),
-            }
-            .map(|hashed| spark_hex(&[hashed]).unwrap())
+        [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), 
ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => {
+            compute_sha2(
+                bit_length_arg,
+                &[ColumnarValue::from(ScalarValue::Utf8(expr_arg))],
+            )
         }
-        [ColumnarValue::Array(expr_arg), 
ColumnarValue::Scalar(ScalarValue::UInt32(Some(bit_length_arg)))] => {
-            match bit_length_arg {
-                0 | 256 => sha256(&[ColumnarValue::from(expr_arg)]),
-                224 => sha224(&[ColumnarValue::from(expr_arg)]),
-                384 => sha384(&[ColumnarValue::from(expr_arg)]),
-                512 => sha512(&[ColumnarValue::from(expr_arg)]),
-                _ => exec_err!(
-                    "sha2 function only supports 224, 256, 384, and 512 bit 
lengths."
-                ),
-            }
-            .map(|hashed| spark_hex(&[hashed]).unwrap())
+        [ColumnarValue::Array(expr_arg), 
ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => {
+            compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)])
         }
         [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), 
ColumnarValue::Array(bit_length_arg)] =>
         {
             let arr: StringArray = bit_length_arg
-                .as_primitive::<UInt32Type>()
+                .as_primitive::<Int32Type>()
                 .iter()
                 .map(|bit_length| {
                     match sha2([
                         
ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())),
-                        ColumnarValue::Scalar(ScalarValue::UInt32(bit_length)),
+                        ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
                     ])
                     .unwrap()
                     {
@@ -188,7 +173,7 @@ pub fn sha2(args: [ColumnarValue; 2]) -> 
Result<ColumnarValue> {
         }
         [ColumnarValue::Array(expr_arg), ColumnarValue::Array(bit_length_arg)] 
=> {
             let expr_iter = expr_arg.as_string::<i32>().iter();
-            let bit_length_iter = 
bit_length_arg.as_primitive::<UInt32Type>().iter();
+            let bit_length_iter = 
bit_length_arg.as_primitive::<Int32Type>().iter();
             let arr: StringArray = expr_iter
                 .zip(bit_length_iter)
                 .map(|(expr, bit_length)| {
@@ -196,7 +181,7 @@ pub fn sha2(args: [ColumnarValue; 2]) -> 
Result<ColumnarValue> {
                         ColumnarValue::Scalar(ScalarValue::Utf8(Some(
                             expr.unwrap().to_string(),
                         ))),
-                        ColumnarValue::Scalar(ScalarValue::UInt32(bit_length)),
+                        ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
                     ])
                     .unwrap()
                     {
@@ -215,3 +200,21 @@ pub fn sha2(args: [ColumnarValue; 2]) -> 
Result<ColumnarValue> {
         _ => exec_err!("Unsupported argument types for sha2 function"),
     }
 }
+
+fn compute_sha2(
+    bit_length_arg: i32,
+    expr_arg: &[ColumnarValue],
+) -> Result<ColumnarValue> {
+    match bit_length_arg {
+        0 | 256 => sha256(expr_arg),
+        224 => sha224(expr_arg),
+        384 => sha384(expr_arg),
+        512 => sha512(expr_arg),
+        _ => {
+            // Return null for unsupported bit lengths instead of error, because spark sha2 does not
+            // error out for this.
+            return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+        }
+    }
+    .map(|hashed| spark_sha2_hex(&[hashed]).unwrap())
+}
diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs
index 74ec7641b3..614d1d4e9a 100644
--- a/datafusion/spark/src/function/math/hex.rs
+++ b/datafusion/spark/src/function/math/hex.rs
@@ -159,13 +159,28 @@ fn hex_encode<T: AsRef<[u8]>>(data: T, lower_case: bool) 
-> String {
 }
 
 #[inline(always)]
-fn hex_bytes<T: AsRef<[u8]>>(bytes: T) -> Result<String, std::fmt::Error> {
-    let hex_string = hex_encode(bytes, false);
+fn hex_bytes<T: AsRef<[u8]>>(
+    bytes: T,
+    lowercase: bool,
+) -> Result<String, std::fmt::Error> {
+    let hex_string = hex_encode(bytes, lowercase);
     Ok(hex_string)
 }
 
 /// Spark-compatible `hex` function
 pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, 
DataFusionError> {
+    compute_hex(args, false)
+}
+
+/// Spark-compatible `sha2` function
+pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, 
DataFusionError> {
+    compute_hex(args, true)
+}
+
+pub fn compute_hex(
+    args: &[ColumnarValue],
+    lowercase: bool,
+) -> Result<ColumnarValue, DataFusionError> {
     if args.len() != 1 {
         return Err(DataFusionError::Internal(
             "hex expects exactly one argument".to_string(),
@@ -192,7 +207,7 @@ pub fn spark_hex(args: &[ColumnarValue]) -> 
Result<ColumnarValue, DataFusionErro
 
                 let hexed: StringArray = array
                     .iter()
-                    .map(|v| v.map(hex_bytes).transpose())
+                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
                     .collect::<Result<_, _>>()?;
 
                 Ok(ColumnarValue::Array(Arc::new(hexed)))
@@ -202,7 +217,7 @@ pub fn spark_hex(args: &[ColumnarValue]) -> 
Result<ColumnarValue, DataFusionErro
 
                 let hexed: StringArray = array
                     .iter()
-                    .map(|v| v.map(hex_bytes).transpose())
+                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
                     .collect::<Result<_, _>>()?;
 
                 Ok(ColumnarValue::Array(Arc::new(hexed)))
@@ -212,7 +227,7 @@ pub fn spark_hex(args: &[ColumnarValue]) -> 
Result<ColumnarValue, DataFusionErro
 
                 let hexed: StringArray = array
                     .iter()
-                    .map(|v| v.map(hex_bytes).transpose())
+                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
                     .collect::<Result<_, _>>()?;
 
                 Ok(ColumnarValue::Array(Arc::new(hexed)))
@@ -222,7 +237,7 @@ pub fn spark_hex(args: &[ColumnarValue]) -> 
Result<ColumnarValue, DataFusionErro
 
                 let hexed: StringArray = array
                     .iter()
-                    .map(|v| v.map(hex_bytes).transpose())
+                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
                     .collect::<Result<_, _>>()?;
 
                 Ok(ColumnarValue::Array(Arc::new(hexed)))
@@ -237,11 +252,11 @@ pub fn spark_hex(args: &[ColumnarValue]) -> 
Result<ColumnarValue, DataFusionErro
                         .collect::<Vec<_>>(),
                     DataType::Utf8 => as_string_array(dict.values())
                         .iter()
-                        .map(|v| v.map(hex_bytes).transpose())
+                        .map(|v| v.map(|b| hex_bytes(b, 
lowercase)).transpose())
                         .collect::<Result<_, _>>()?,
                     DataType::Binary => as_binary_array(dict.values())?
                         .iter()
-                        .map(|v| v.map(hex_bytes).transpose())
+                        .map(|v| v.map(|b| hex_bytes(b, 
lowercase)).transpose())
                         .collect::<Result<_, _>>()?,
                     _ => exec_err!(
                         "hex got an unexpected argument type: {:?}",
diff --git a/datafusion/sqllogictest/test_files/spark/hash/sha2.slt b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt
index e2341df164..7690a38773 100644
--- a/datafusion/sqllogictest/test_files/spark/hash/sha2.slt
+++ b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt
@@ -18,48 +18,60 @@
 query T
 SELECT sha2('Spark', 0::INT);
 ----
-529BC3B07127ECB7E53A4DCF1991D9152C24537D919178022B2C42657F79A26B
+529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b
 
 query T
 SELECT sha2('Spark', 256::INT);
 ----
-529BC3B07127ECB7E53A4DCF1991D9152C24537D919178022B2C42657F79A26B
+529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b
 
 query T
 SELECT sha2('Spark', 224::INT);
 ----
-DBEAB94971678D36AF2195851C0F7485775A2A7C60073D62FC04549C
+dbeab94971678d36af2195851c0f7485775a2a7c60073d62fc04549c
 
 query T
 SELECT sha2('Spark', 384::INT);
 ----
-1E40B8D06C248A1CC32428C22582B6219D072283078FA140D9AD297ECADF2CABEFC341B857AD36226AA8D6D79F2AB67D
+1e40b8d06c248a1cc32428c22582b6219d072283078fa140d9ad297ecadf2cabefc341b857ad36226aa8d6d79f2ab67d
 
 query T
 SELECT sha2('Spark', 512::INT);
 ----
-44844A586C54C9A212DA1DBFE05C5F1705DE1AF5FDA1F0D36297623249B279FD8F0CCEC03F888F4FB13BF7CD83FDAD58591C797F81121A23CFDD5E0897795238
+44844a586c54c9a212da1dbfe05c5f1705de1af5fda1f0d36297623249b279fd8f0ccec03f888f4fb13bf7cd83fdad58591c797f81121a23cfdd5e0897795238
+
+query T
+SELECT sha2('Spark', 128::INT);
+----
+NULL
 
 query T
 SELECT sha2(expr, 256::INT) FROM VALUES ('foo'), ('bar') AS t(expr);
 ----
-2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE
-FCDE2B2EDBA56BF408601FB721FE9B5C338D10EE429EA04FAE5511B68FBF8FB9
+2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
+fcde2b2edba56bf408601fb721fe9b5c338d10ee429ea04fae5511b68fbf8fb9
 
 query T
-SELECT sha2('foo', bit_length) FROM VALUES (0::INT), (256::INT), (224::INT), 
(384::INT), (512::INT) AS t(bit_length);
+SELECT sha2(expr, 128::INT) FROM VALUES ('foo'), ('bar') AS t(expr);
 ----
-2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE
-2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE
-0808F64E60D58979FCB676C96EC938270DEA42445AEEFCD3A4E6F8DB
-98C11FFDFDD540676B1A137CB1A22B2A70350C9A44171D6B1180C6BE5CBB2EE3F79D532C8A1DD9EF2E8E08E752A3BABB
-F7FBBA6E0636F890E56FBBF3283E524C6FA3204AE298382D624741D0DC6638326E282C41BE5E4254D8820772C5518A2C5A8C0C7F7EDA19594A7EB539453E1ED7
+NULL
+NULL
 
+query T
+SELECT sha2('foo', bit_length) FROM VALUES (0::INT), (256::INT), (224::INT), 
(384::INT), (512::INT), (128::INT) AS t(bit_length);
+----
+2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
+2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
+0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
+98c11ffdfdd540676b1a137cb1a22b2a70350c9a44171d6b1180c6be5cbb2ee3f79d532c8a1dd9ef2e8e08e752a3babb
+f7fbba6e0636f890e56fbbf3283e524c6fa3204ae298382d624741d0dc6638326e282c41be5e4254d8820772c5518a2c5a8c0c7f7eda19594a7eb539453e1ed7
+NULL
 
 query T
-SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), 
('baz',384::INT), ('qux',512::INT) AS t(expr, bit_length);
+SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), 
('baz',384::INT), ('qux',512::INT), ('qux',128::INT) AS t(expr, bit_length);
 ----
-2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE
-07DAF010DE7F7F0D8D76A76EB8D1EB40182C8D1E7A3877A6686C9BF0
-967004D25DE4ABC1BD6A7C9A216254A5AC0733E8AD96DC9F1EA0FAD9619DA7C32D654EC8AD8BA2F9B5728FED6633BD91
-8C6BE9ED448A34883A13A13F4EAD4AEFA036B67DCDA59020C01E57EA075EA8A4792D428F2C6FD0C09D1C49994D6C22789336E062188DF29572ED07E7F9779C52
+2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
+07daf010de7f7f0d8d76a76eb8d1eb40182c8d1e7a3877a6686c9bf0
+967004d25de4abc1bd6a7c9a216254a5ac0733e8ad96dc9f1ea0fad9619da7c32d654ec8ad8ba2f9b5728fed6633bd91
+8c6be9ed448a34883a13a13f4ead4aefa036b67dcda59020c01e57ea075ea8a4792d428f2c6fd0c09d1c49994d6c22789336e062188df29572ed07e7f9779c52
+NULL


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to