This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 93af7043 fix: Compute murmur3 hash with dictionary input correctly 
(#433)
93af7043 is described below

commit 93af70438b92049226dfd130e04dd83a9863f1a9
Author: advancedxy <[email protected]>
AuthorDate: Fri May 24 23:13:24 2024 +0800

    fix: Compute murmur3 hash with dictionary input correctly (#433)
    
    * fix: Handle compute murmur3 hash with dictionary input correctly
    
    * add unit tests
    
    * spotless apply
    
    * apply scala fix
    
    * address comment
    
    * another style issue
    
    * Update spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
    
    ---------
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
---
 core/src/execution/datafusion/spark_hash.rs        | 270 +++++++++------------
 .../org/apache/comet/CometExpressionSuite.scala    |  58 ++++-
 2 files changed, 163 insertions(+), 165 deletions(-)

diff --git a/core/src/execution/datafusion/spark_hash.rs 
b/core/src/execution/datafusion/spark_hash.rs
index aa4269dd..6d25a72f 100644
--- a/core/src/execution/datafusion/spark_hash.rs
+++ b/core/src/execution/datafusion/spark_hash.rs
@@ -17,7 +17,10 @@
 
 //! This includes utilities for hashing and murmur3 hashing.
 
-use arrow::datatypes::{ArrowNativeTypeOp, UInt16Type, UInt32Type, UInt64Type, 
UInt8Type};
+use arrow::{
+    compute::take,
+    datatypes::{ArrowNativeTypeOp, UInt16Type, UInt32Type, UInt64Type, 
UInt8Type},
+};
 use std::sync::Arc;
 
 use datafusion::{
@@ -95,19 +98,8 @@ pub(crate) fn spark_compatible_murmur3_hash<T: 
AsRef<[u8]>>(data: T, seed: u32)
     }
 }
 
-#[test]
-fn test_murmur3() {
-    let _hashes = ["", "a", "ab", "abc", "abcd", "abcde"]
-        .into_iter()
-        .map(|s| spark_compatible_murmur3_hash(s.as_bytes(), 42) as i32)
-        .collect::<Vec<_>>();
-    let _expected = vec![
-        142593372, 1485273170, -97053317, 1322437556, -396302900, 814637928,
-    ];
-}
-
 macro_rules! hash_array {
-    ($array_type:ident, $column: ident, $hashes: ident) => {
+    ($array_type: ident, $column: ident, $hashes: ident) => {
         let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
         if array.null_count() == 0 {
             for (i, hash) in $hashes.iter_mut().enumerate() {
@@ -123,8 +115,31 @@ macro_rules! hash_array {
     };
 }
 
+macro_rules! hash_array_boolean {
+    ($array_type: ident, $column: ident, $hash_input_type: ident, $hashes: 
ident) => {
+        let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
+        if array.null_count() == 0 {
+            for (i, hash) in $hashes.iter_mut().enumerate() {
+                *hash = spark_compatible_murmur3_hash(
+                    $hash_input_type::from(array.value(i)).to_le_bytes(),
+                    *hash,
+                );
+            }
+        } else {
+            for (i, hash) in $hashes.iter_mut().enumerate() {
+                if !array.is_null(i) {
+                    *hash = spark_compatible_murmur3_hash(
+                        $hash_input_type::from(array.value(i)).to_le_bytes(),
+                        *hash,
+                    );
+                }
+            }
+        }
+    };
+}
+
 macro_rules! hash_array_primitive {
-    ($array_type:ident, $column: ident, $ty: ident, $hashes: ident) => {
+    ($array_type: ident, $column: ident, $ty: ident, $hashes: ident) => {
         let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
         let values = array.values();
 
@@ -143,7 +158,7 @@ macro_rules! hash_array_primitive {
 }
 
 macro_rules! hash_array_primitive_float {
-    ($array_type:ident, $column: ident, $ty: ident, $ty2: ident, $hashes: 
ident) => {
+    ($array_type: ident, $column: ident, $ty: ident, $ty2: ident, $hashes: 
ident) => {
         let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
         let values = array.values();
 
@@ -172,7 +187,7 @@ macro_rules! hash_array_primitive_float {
 }
 
 macro_rules! hash_array_decimal {
-    ($array_type:ident, $column: ident, $hashes: ident) => {
+    ($array_type: ident, $column: ident, $hashes: ident) => {
         let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
 
         if array.null_count() == 0 {
@@ -193,27 +208,33 @@ macro_rules! hash_array_decimal {
 fn create_hashes_dictionary<K: ArrowDictionaryKeyType>(
     array: &ArrayRef,
     hashes_buffer: &mut [u32],
+    first_col: bool,
 ) -> Result<()> {
     let dict_array = 
array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
-
-    // Hash each dictionary value once, and then use that computed
-    // hash for each key value to avoid a potentially expensive
-    // redundant hashing for large dictionary elements (e.g. strings)
-    let dict_values = Arc::clone(dict_array.values());
-    let mut dict_hashes = vec![0; dict_values.len()];
-    create_hashes(&[dict_values], &mut dict_hashes)?;
-
-    for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) {
-        if let Some(key) = key {
-            let idx = key.to_usize().ok_or_else(|| {
-                DataFusionError::Internal(format!(
-                    "Can not convert key value {:?} to usize in dictionary of 
type {:?}",
-                    key,
-                    dict_array.data_type()
-                ))
-            })?;
-            *hash = dict_hashes[idx]
-        } // no update for Null, consistent with other hashes
+    if !first_col {
+        // unpack the dictionary array as each row may have a different hash 
input
+        let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), 
None)?;
+        create_hashes(&[unpacked], hashes_buffer)?;
+    } else {
+        // For the first column, hash each dictionary value once, and then use
+        // that computed hash for each key value to avoid a potentially
+        // expensive redundant hashing for large dictionary elements (e.g. 
strings)
+        let dict_values = Arc::clone(dict_array.values());
+        // same initial seed as Spark
+        let mut dict_hashes = vec![42; dict_values.len()];
+        create_hashes(&[dict_values], &mut dict_hashes)?;
+        for (hash, key) in 
hashes_buffer.iter_mut().zip(dict_array.keys().iter()) {
+            if let Some(key) = key {
+                let idx = key.to_usize().ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "Can not convert key value {:?} to usize in dictionary 
of type {:?}",
+                        key,
+                        dict_array.data_type()
+                    ))
+                })?;
+                *hash = dict_hashes[idx]
+            } // no update for Null, consistent with other hashes
+        }
     }
     Ok(())
 }
@@ -227,27 +248,11 @@ pub fn create_hashes<'a>(
     arrays: &[ArrayRef],
     hashes_buffer: &'a mut [u32],
 ) -> Result<&'a mut [u32]> {
-    for col in arrays {
+    for (i, col) in arrays.iter().enumerate() {
+        let first_col = i == 0;
         match col.data_type() {
             DataType::Boolean => {
-                let array = 
col.as_any().downcast_ref::<BooleanArray>().unwrap();
-                if array.null_count() == 0 {
-                    for (i, hash) in hashes_buffer.iter_mut().enumerate() {
-                        *hash = spark_compatible_murmur3_hash(
-                            i32::from(array.value(i)).to_le_bytes(),
-                            *hash,
-                        );
-                    }
-                } else {
-                    for (i, hash) in hashes_buffer.iter_mut().enumerate() {
-                        if !array.is_null(i) {
-                            *hash = spark_compatible_murmur3_hash(
-                                i32::from(array.value(i)).to_le_bytes(),
-                                *hash,
-                            );
-                        }
-                    }
-                }
+                hash_array_boolean!(BooleanArray, col, i32, hashes_buffer);
             }
             DataType::Int8 => {
                 hash_array_primitive!(Int8Array, col, i32, hashes_buffer);
@@ -305,28 +310,28 @@ pub fn create_hashes<'a>(
             }
             DataType::Dictionary(index_type, _) => match **index_type {
                 DataType::Int8 => {
-                    create_hashes_dictionary::<Int8Type>(col, hashes_buffer)?;
+                    create_hashes_dictionary::<Int8Type>(col, hashes_buffer, 
first_col)?;
                 }
                 DataType::Int16 => {
-                    create_hashes_dictionary::<Int16Type>(col, hashes_buffer)?;
+                    create_hashes_dictionary::<Int16Type>(col, hashes_buffer, 
first_col)?;
                 }
                 DataType::Int32 => {
-                    create_hashes_dictionary::<Int32Type>(col, hashes_buffer)?;
+                    create_hashes_dictionary::<Int32Type>(col, hashes_buffer, 
first_col)?;
                 }
                 DataType::Int64 => {
-                    create_hashes_dictionary::<Int64Type>(col, hashes_buffer)?;
+                    create_hashes_dictionary::<Int64Type>(col, hashes_buffer, 
first_col)?;
                 }
                 DataType::UInt8 => {
-                    create_hashes_dictionary::<UInt8Type>(col, hashes_buffer)?;
+                    create_hashes_dictionary::<UInt8Type>(col, hashes_buffer, 
first_col)?;
                 }
                 DataType::UInt16 => {
-                    create_hashes_dictionary::<UInt16Type>(col, 
hashes_buffer)?;
+                    create_hashes_dictionary::<UInt16Type>(col, hashes_buffer, 
first_col)?;
                 }
                 DataType::UInt32 => {
-                    create_hashes_dictionary::<UInt32Type>(col, 
hashes_buffer)?;
+                    create_hashes_dictionary::<UInt32Type>(col, hashes_buffer, 
first_col)?;
                 }
                 DataType::UInt64 => {
-                    create_hashes_dictionary::<UInt64Type>(col, 
hashes_buffer)?;
+                    create_hashes_dictionary::<UInt64Type>(col, hashes_buffer, 
first_col)?;
                 }
                 _ => {
                     return Err(DataFusionError::Internal(format!(
@@ -363,78 +368,64 @@ mod tests {
     use crate::execution::datafusion::spark_hash::{create_hashes, pmod};
     use datafusion::arrow::array::{ArrayRef, Int32Array, Int64Array, 
Int8Array, StringArray};
 
-    macro_rules! test_hashes {
-        ($ty:ty, $values:expr, $expected:expr) => {
-            let i = Arc::new(<$ty>::from($values)) as ArrayRef;
-            let mut hashes = vec![42; $values.len()];
+    macro_rules! test_hashes_internal {
+        ($input: expr, $len: expr, $expected: expr) => {
+            let i = $input as ArrayRef;
+            let mut hashes = vec![42; $len];
             create_hashes(&[i], &mut hashes).unwrap();
             assert_eq!(hashes, $expected);
         };
     }
 
+    fn test_murmur3_hash<I: Clone, T: arrow_array::Array + 
From<Vec<Option<I>>> + 'static>(
+        values: Vec<Option<I>>,
+        expected: Vec<u32>,
+    ) {
+        // copied before inserting nulls
+        let mut input_with_nulls = values.clone();
+        let mut expected_with_nulls = expected.clone();
+        let len = values.len();
+        let i = Arc::new(T::from(values)) as ArrayRef;
+        test_hashes_internal!(i, len, expected);
+
+        // test with nulls
+        let median = len / 2;
+        input_with_nulls.insert(0, None);
+        input_with_nulls.insert(median, None);
+        expected_with_nulls.insert(0, 42);
+        expected_with_nulls.insert(median, 42);
+        let with_nulls_len = len + 2;
+        let nullable_input = Arc::new(T::from(input_with_nulls)) as ArrayRef;
+        test_hashes_internal!(nullable_input, with_nulls_len, 
expected_with_nulls);
+    }
+
     #[test]
     fn test_i8() {
-        test_hashes!(
-            Int8Array,
+        test_murmur3_hash::<i8, Int8Array>(
             vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)],
-            vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x43b4d8ed, 0x422a1365]
-        );
-        // with null input
-        test_hashes!(
-            Int8Array,
-            vec![Some(1), None, Some(-1), Some(i8::MAX), Some(i8::MIN)],
-            vec![0xdea578e3, 42, 0xa0590e3d, 0x43b4d8ed, 0x422a1365]
+            vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x43b4d8ed, 0x422a1365],
         );
     }
 
     #[test]
     fn test_i32() {
-        test_hashes!(
-            Int32Array,
+        test_murmur3_hash::<i32, Int32Array>(
             vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)],
-            vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x07fb67e7, 0x2b1f0fc6]
-        );
-        // with null input
-        test_hashes!(
-            Int32Array,
-            vec![
-                Some(1),
-                Some(0),
-                Some(-1),
-                None,
-                Some(i32::MAX),
-                Some(i32::MIN)
-            ],
-            vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 42, 0x07fb67e7, 
0x2b1f0fc6]
+            vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x07fb67e7, 0x2b1f0fc6],
         );
     }
 
     #[test]
     fn test_i64() {
-        test_hashes!(
-            Int64Array,
+        test_murmur3_hash::<i64, Int64Array>(
             vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)],
-            vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb]
-        );
-        // with null input
-        test_hashes!(
-            Int64Array,
-            vec![
-                Some(1),
-                Some(0),
-                Some(-1),
-                None,
-                Some(i64::MAX),
-                Some(i64::MIN)
-            ],
-            vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 42, 0xa05b5d7b, 
0xcd1e64fb]
+            vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb],
         );
     }
 
     #[test]
     fn test_f32() {
-        test_hashes!(
-            Float32Array,
+        test_murmur3_hash::<f32, Float32Array>(
             vec![
                 Some(1.0),
                 Some(0.0),
@@ -443,28 +434,15 @@ mod tests {
                 Some(99999999999.99999999999),
                 Some(-99999999999.99999999999),
             ],
-            vec![0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 0xcbdc340f, 
0xc0361c86]
-        );
-        // with null input
-        test_hashes!(
-            Float32Array,
             vec![
-                Some(1.0),
-                Some(0.0),
-                Some(-0.0),
-                Some(-1.0),
-                None,
-                Some(99999999999.99999999999),
-                Some(-99999999999.99999999999)
+                0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 0xcbdc340f, 
0xc0361c86,
             ],
-            vec![0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 42, 
0xcbdc340f, 0xc0361c86]
         );
     }
 
     #[test]
     fn test_f64() {
-        test_hashes!(
-            Float64Array,
+        test_murmur3_hash::<f64, Float64Array>(
             vec![
                 Some(1.0),
                 Some(0.0),
@@ -473,44 +451,26 @@ mod tests {
                 Some(99999999999.99999999999),
                 Some(-99999999999.99999999999),
             ],
-            vec![0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 0xb87e1595, 
0xa0eef9f9]
-        );
-        // with null input
-        test_hashes!(
-            Float64Array,
             vec![
-                Some(1.0),
-                Some(0.0),
-                Some(-0.0),
-                Some(-1.0),
-                None,
-                Some(99999999999.99999999999),
-                Some(-99999999999.99999999999)
+                0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 0xb87e1595, 
0xa0eef9f9,
             ],
-            vec![0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 42, 
0xb87e1595, 0xa0eef9f9]
         );
     }
 
     #[test]
     fn test_str() {
-        test_hashes!(
-            StringArray,
-            vec!["hello", "bar", "", "😁", "天地"],
-            vec![3286402344, 2486176763, 142593372, 885025535, 2395000894]
-        );
-        // test with null input
-        test_hashes!(
-            StringArray,
-            vec![
-                Some("hello"),
-                Some("bar"),
-                None,
-                Some(""),
-                Some("😁"),
-                Some("天地")
-            ],
-            vec![3286402344, 2486176763, 42, 142593372, 885025535, 2395000894]
-        );
+        let input = vec![
+            "hello", "bar", "", "😁", "天地", "a", "ab", "abc", "abcd", "abcde",
+        ]
+        .iter()
+        .map(|s| Some(s.to_string()))
+        .collect::<Vec<Option<String>>>();
+        let expected: Vec<u32> = vec![
+            3286402344, 2486176763, 142593372, 885025535, 2395000894, 
1485273170, 0xfa37157b,
+            1322437556, 0xe860e5cc, 814637928,
+        ];
+
+        test_murmur3_hash::<String, StringArray>(input, expected);
     }
 
     #[test]
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 6ca4baf6..99261508 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1454,17 +1454,55 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
         withTable(table) {
           sql(s"create table $table(col string, a int, b float) using parquet")
           sql(s"""
-             |insert into $table values
-             |('Spark SQL  ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), 
('苹果手机', NULL, 3.999999)
-             |, ('Spark SQL  ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), 
('苹果手机', NULL, 3.999999)
-             |""".stripMargin)
+              |insert into $table values
+              |('Spark SQL  ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), 
('苹果手机', NULL, 3.999999)
+              |, ('Spark SQL  ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), 
('苹果手机', NULL, 3.999999)
+              |""".stripMargin)
           checkSparkAnswerAndOperator("""
-               |select
-               |md5(col), md5(cast(a as string)), md5(cast(b as string)),
-               |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), 
hash(b, a, col),
-               |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), 
sha2(col, 512), sha2(col, 128)
-               |from test
-               |""".stripMargin)
+              |select
+              |md5(col), md5(cast(a as string)), md5(cast(b as string)),
+              |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, 
a, col),
+              |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), 
sha2(col, 512), sha2(col, 128)
+              |from test
+              |""".stripMargin)
+        }
+      }
+    }
+  }
+
+  test("hash functions with random input") {
+    val dataGen = DataGenerator.DEFAULT
+    // sufficient number of rows to create dictionary encoded ArrowArray.
+    val randomNumRows = 1000
+
+    val whitespaceChars = " \t\r\n"
+    val timestampPattern = "0123456789/:T" + whitespaceChars
+    Seq(true, false).foreach { dictionary =>
+      withSQLConf(
+        "parquet.enable.dictionary" -> dictionary.toString,
+        CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
+        val table = "test"
+        withTable(table) {
+          sql(s"create table $table(col string, a int, b float) using parquet")
+          // TODO: Add a Row generator in the data gen class and replace the 
following code
+          val col = dataGen.generateStrings(randomNumRows, timestampPattern, 6)
+          val colA = dataGen.generateInts(randomNumRows)
+          val colB = dataGen.generateFloats(randomNumRows)
+          val data = col.zip(colA).zip(colB).map { case ((a, b), c) => (a, b, 
c) }
+          data
+            .toDF("col", "a", "b")
+            .write
+            .mode("append")
+            .insertInto(table)
+          // with random generated data
+          // disable cast(b as string) for now, as the cast from float to 
string may produce incompatible result
+          checkSparkAnswerAndOperator("""
+              |select
+              |md5(col), md5(cast(a as string)), --md5(cast(b as string)),
+              |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, 
a, col),
+              |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), 
sha2(col, 512), sha2(col, 128)
+              |from test
+              |""".stripMargin)
         }
       }
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to