This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 1103a6c1 feat: Support more types with BloomFilterAgg (#1039)
1103a6c1 is described below

commit 1103a6c193b92cf0ce655b7e18a5a44aa4332476
Author: Matt Butrovich <[email protected]>
AuthorDate: Thu Oct 31 15:14:59 2024 -0400

    feat: Support more types with BloomFilterAgg (#1039)
    
    * Add support for other integer types. Next up is strings.
    
    * Test refactor.
    
    * Add string support for agg. Should probably go back and add support to 
BloomFilterMightContain.
    
    * Fix typo.
    
    * Minor test refactor to reduce copy-paste.
    
    * Add type checking in plan conversion, and fix signature on native side.
---
 .../datafusion/expressions/bloom_filter_agg.rs     | 35 ++++++++++++++++++----
 .../datafusion/util/spark_bloom_filter.rs          | 17 +++++++++++
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 14 ++++++---
 .../org/apache/comet/exec/CometExecSuite.scala     |  8 +++--
 4 files changed, 63 insertions(+), 11 deletions(-)

diff --git 
a/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs 
b/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
index ed64b80e..e6528a56 100644
--- a/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
+++ b/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
@@ -62,7 +62,17 @@ impl BloomFilterAgg {
         assert!(matches!(data_type, DataType::Binary));
         Self {
             name: name.into(),
-            signature: Signature::exact(vec![DataType::Int64], 
Volatility::Immutable),
+            signature: Signature::uniform(
+                1,
+                vec![
+                    DataType::Int8,
+                    DataType::Int16,
+                    DataType::Int32,
+                    DataType::Int64,
+                    DataType::Utf8,
+                ],
+                Volatility::Immutable,
+            ),
             expr,
             num_items: extract_i32_from_literal(num_items),
             num_bits: extract_i32_from_literal(num_bits),
@@ -112,10 +122,25 @@ impl Accumulator for SparkBloomFilter {
         (0..arr.len()).try_for_each(|index| {
             let v = ScalarValue::try_from_array(arr, index)?;
 
-            if let ScalarValue::Int64(Some(value)) = v {
-                self.put_long(value);
-            } else {
-                unreachable!()
+            match v {
+                ScalarValue::Int8(Some(value)) => {
+                    self.put_long(value as i64);
+                }
+                ScalarValue::Int16(Some(value)) => {
+                    self.put_long(value as i64);
+                }
+                ScalarValue::Int32(Some(value)) => {
+                    self.put_long(value as i64);
+                }
+                ScalarValue::Int64(Some(value)) => {
+                    self.put_long(value);
+                }
+                ScalarValue::Utf8(Some(value)) => {
+                    self.put_binary(value.as_bytes());
+                }
+                _ => {
+                    unreachable!()
+                }
             }
             Ok(())
         })
diff --git a/native/core/src/execution/datafusion/util/spark_bloom_filter.rs 
b/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
index 22a84d85..35fa23b4 100644
--- a/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
+++ b/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
@@ -115,6 +115,23 @@ impl SparkBloomFilter {
         bit_changed
     }
 
+    pub fn put_binary(&mut self, item: &[u8]) -> bool {
+        // Hash the input bytes into 2 int hash values, h1 (seed 0) and h2 (seeded with h1), then
+        // produce n hash values by `h1 + i * h2` with 1 <= i <= num_hash_functions, set one bit each.
+        let h1 = spark_compatible_murmur3_hash(item, 0);
+        let h2 = spark_compatible_murmur3_hash(item, h1);
+        let bit_size = self.bits.bit_size() as i32;
+        let mut bit_changed = false;
+        for i in 1..=self.num_hash_functions {
+            let mut combined_hash = (h1 as i32).add_wrapping((i as 
i32).mul_wrapping(h2 as i32));
+            if combined_hash < 0 {
+                combined_hash = !combined_hash;
+            }
+            bit_changed |= self.bits.set((combined_hash % bit_size) as usize)
+        }
+        bit_changed
+    }
+
     pub fn might_contain_long(&self, item: i64) -> bool {
         let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
         let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 3805d418..abb138b0 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -769,11 +769,17 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde with CometExprShim
         val numBitsExpr = exprToProto(numBits, inputs, binding)
         val dataType = serializeDataType(bloom_filter.dataType)
 
-        // TODO: Support more types
-        //  https://github.com/apache/datafusion-comet/issues/1023
         if (childExpr.isDefined &&
-          child.dataType
-            .isInstanceOf[LongType] &&
+          (child.dataType
+            .isInstanceOf[ByteType] ||
+            child.dataType
+              .isInstanceOf[ShortType] ||
+            child.dataType
+              .isInstanceOf[IntegerType] ||
+            child.dataType
+              .isInstanceOf[LongType] ||
+            child.dataType
+              .isInstanceOf[StringType]) &&
           numItemsExpr.isDefined &&
           numBitsExpr.isDefined &&
           dataType.isDefined) {
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index a720842c..99007d0c 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -946,8 +946,12 @@ class CometExecSuite extends CometTestBase {
       (0 until 100)
         .map(_ => (Random.nextInt(), Random.nextInt() % 5)),
       "tbl") {
-      val df = sql("SELECT bloom_filter_agg(cast(_2 as long)) FROM tbl")
-      checkSparkAnswerAndOperator(df)
+
+      (if (isSpark35Plus) Seq("tinyint", "short", "int", "long", "string") 
else Seq("long"))
+        .foreach { input_type =>
+          val df = sql(f"SELECT bloom_filter_agg(cast(_2 as $input_type)) FROM 
tbl")
+          checkSparkAnswerAndOperator(df)
+        }
     }
 
     spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to