This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 969f683  feat: Support BloomFilterMightContain expr (#179)
969f683 is described below

commit 969f683c188fbcf2956da784b10777445f9adbc0
Author: advancedxy <xian...@apache.org>
AuthorDate: Fri Mar 15 05:24:16 2024 +0800

    feat: Support BloomFilterMightContain expr (#179)
---
 core/src/common/bit.rs                             |  11 ++
 .../expressions/bloom_filter_might_contain.rs      | 152 +++++++++++++++++++++
 core/src/execution/datafusion/expressions/mod.rs   |   1 +
 core/src/execution/datafusion/mod.rs               |   1 +
 core/src/execution/datafusion/planner.rs           |  10 ++
 core/src/execution/datafusion/spark_hash.rs        |   2 +-
 core/src/execution/datafusion/{ => util}/mod.rs    |   9 +-
 .../execution/datafusion/util/spark_bit_array.rs   | 131 ++++++++++++++++++
 .../datafusion/util/spark_bloom_filter.rs          |  98 +++++++++++++
 core/src/execution/proto/expr.proto                |   6 +
 pom.xml                                            |  10 ++
 spark/pom.xml                                      |  18 +++
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  17 +++
 .../apache/comet/shims/ShimQueryPlanSerde.scala    |   7 +-
 .../apache/comet/CometExpression3_3PlusSuite.scala | 106 ++++++++++++++
 15 files changed, 570 insertions(+), 9 deletions(-)

diff --git a/core/src/common/bit.rs b/core/src/common/bit.rs
index 4af560f..f736347 100644
--- a/core/src/common/bit.rs
+++ b/core/src/common/bit.rs
@@ -131,6 +131,17 @@ pub fn read_num_bytes_u32(size: usize, src: &[u8]) -> u32 {
     trailing_bits(v as u64, size * 8) as u32
 }
 
+/// Similar to `read_num_bytes`, but reads numbers from the bytes in big-endian order.
+/// This is used to read bytes written by Java's OutputStream, which writes in big-endian.
+macro_rules! read_num_be_bytes {
+    ($ty:ty, $size:expr, $src:expr) => {{
+        debug_assert!($size <= $src.len());
+        let mut buffer = <$ty as $crate::common::bit::FromBytes>::Buffer::default();
+        buffer.as_mut()[..$size].copy_from_slice(&$src[..$size]);
+        <$ty>::from_be_bytes(buffer)
+    }};
+}
+
 /// Converts value `val` of type `T` to a byte vector, by reading `num_bytes` from `val`.
 /// NOTE: if `val` is less than the size of `T` then it can be truncated.
 #[inline]
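
For reference, the macro above mirrors plain big-endian decoding from the standard library. A minimal standalone sketch (using only `from_be_bytes`, not the crate's `FromBytes` buffer machinery) of decoding the kind of bytes Java's OutputStream writes:

    let buf: [u8; 12] = [
        0, 0, 0, 1,              // i32 = 1 (big-endian), e.g. a version field
        0, 0, 0, 0, 0, 0, 0, 42, // i64 = 42 (big-endian)
    ];
    let version = i32::from_be_bytes(buf[0..4].try_into().unwrap());
    let word = i64::from_be_bytes(buf[4..12].try_into().unwrap());
    assert_eq!((version, word), (1, 42));
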
diff --git a/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs b/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
new file mode 100644
index 0000000..dd90cd8
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
@@ -0,0 +1,152 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{
+    execution::datafusion::util::spark_bloom_filter::SparkBloomFilter, parquet::data_type::AsBytes,
+};
+use arrow::record_batch::RecordBatch;
+use arrow_array::cast::as_primitive_array;
+use arrow_schema::{DataType, Schema};
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
+use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, PhysicalExpr};
+use std::{
+    any::Any,
+    fmt::Display,
+    hash::{Hash, Hasher},
+    sync::Arc,
+};
+
+/// A physical expression that checks if a value might be in a bloom filter. It corresponds to
+/// Spark's `BloomFilterMightContain` expression.
+#[derive(Debug, Hash)]
+pub struct BloomFilterMightContain {
+    pub bloom_filter_expr: Arc<dyn PhysicalExpr>,
+    pub value_expr: Arc<dyn PhysicalExpr>,
+    bloom_filter: Option<SparkBloomFilter>,
+}
+
+impl Display for BloomFilterMightContain {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "BloomFilterMightContain [bloom_filter_expr: {}, value_expr: {}]",
+            self.bloom_filter_expr, self.value_expr
+        )
+    }
+}
+
+impl PartialEq<dyn Any> for BloomFilterMightContain {
+    fn eq(&self, _other: &dyn Any) -> bool {
+        down_cast_any_ref(_other)
+            .downcast_ref::<Self>()
+            .map(|other| {
+                self.bloom_filter_expr.eq(&other.bloom_filter_expr)
+                    && self.value_expr.eq(&other.value_expr)
+            })
+            .unwrap_or(false)
+    }
+}
+
+fn evaluate_bloom_filter(
+    bloom_filter_expr: &Arc<dyn PhysicalExpr>,
+) -> Result<Option<SparkBloomFilter>> {
+    // bloom_filter_expr must be a literal/scalar subquery expression, so we can evaluate it
+    // against an empty batch with an empty schema
+    let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
+    let bloom_filter_bytes = bloom_filter_expr.evaluate(&batch)?;
+    match bloom_filter_bytes {
+        ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
+            Ok(v.map(|v| SparkBloomFilter::new(v.as_bytes())))
+        }
+        _ => internal_err!("Bloom filter expression should be evaluated as a scalar binary value"),
+    }
+}
+
+impl BloomFilterMightContain {
+    pub fn try_new(
+        bloom_filter_expr: Arc<dyn PhysicalExpr>,
+        value_expr: Arc<dyn PhysicalExpr>,
+    ) -> Result<Self> {
+        // eagerly evaluate the bloom_filter_expr to get the actual bloom filter
+        let bloom_filter = evaluate_bloom_filter(&bloom_filter_expr)?;
+        Ok(Self {
+            bloom_filter_expr,
+            value_expr,
+            bloom_filter,
+        })
+    }
+}
+
+impl PhysicalExpr for BloomFilterMightContain {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(true)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        self.bloom_filter
+            .as_ref()
+            .map(|spark_filter| {
+                let values = self.value_expr.evaluate(batch)?;
+                match values {
+                    ColumnarValue::Array(array) => {
+                        let boolean_array =
+                            spark_filter.might_contain_longs(as_primitive_array(&array));
+                        Ok(ColumnarValue::Array(Arc::new(boolean_array)))
+                    }
+                    ColumnarValue::Scalar(ScalarValue::Int64(v)) => {
+                        let result = v.map(|v| spark_filter.might_contain_long(v));
+                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)))
+                    }
+                    _ => internal_err!("value expression should be int64 type"),
+                }
+            })
+            .unwrap_or_else(|| {
+                // when the bloom filter is null, we should return null for all the input rows
+                Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))
+            })
+    }
+
+    fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.bloom_filter_expr.clone(), self.value_expr.clone()]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(BloomFilterMightContain::try_new(
+            children[0].clone(),
+            children[1].clone(),
+        )?))
+    }
+
+    fn dyn_hash(&self, state: &mut dyn Hasher) {
+        let mut s = state;
+        self.bloom_filter_expr.hash(&mut s);
+        self.value_expr.hash(&mut s);
+        self.hash(&mut s);
+    }
+}
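
A rough usage sketch of the new expression (hypothetical values; exact DataFusion module paths may differ across versions): the filter side is a binary literal (or scalar subquery) and the value side an int64 column, and `try_new` decodes the filter once up front so `evaluate` only hashes the batch values.

    use std::sync::Arc;
    use datafusion::physical_plan::expressions::{Column, Literal};
    use datafusion_common::ScalarValue;
    use datafusion_physical_expr::PhysicalExpr;

    // `spark_filter_bytes: Vec<u8>` is a hypothetical, Spark-serialized bloom filter
    // (big-endian version, numHashFunctions, numWords, then the words).
    let bloom_filter_expr: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Binary(Some(spark_filter_bytes))));
    let value_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("col1", 0));
    // Inside a fn returning datafusion_common::Result<_>:
    let might_contain = BloomFilterMightContain::try_new(bloom_filter_expr, value_expr)?;
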
diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs
index cfc3125..69cdf3e 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -26,6 +26,7 @@ pub mod scalar_funcs;
 pub use normalize_nan::NormalizeNaNAndZero;
 pub mod avg;
 pub mod avg_decimal;
+pub mod bloom_filter_might_contain;
 pub mod strings;
 pub mod subquery;
 pub mod sum_decimal;
diff --git a/core/src/execution/datafusion/mod.rs b/core/src/execution/datafusion/mod.rs
index f9fafeb..c464eee 100644
--- a/core/src/execution/datafusion/mod.rs
+++ b/core/src/execution/datafusion/mod.rs
@@ -22,3 +22,4 @@ mod operators;
 pub mod planner;
 pub(crate) mod shuffle_writer;
 mod spark_hash;
+mod util;
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index d52ec80..57c126b 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -56,6 +56,7 @@ use crate::{
                 avg::Avg,
                 avg_decimal::AvgDecimal,
                 bitwise_not::BitwiseNotExpr,
+                bloom_filter_might_contain::BloomFilterMightContain,
                 cast::Cast,
                 checkoverflow::CheckOverflow,
                 if_expr::IfExpr,
@@ -534,6 +535,15 @@ impl PhysicalPlanner {
                 let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
                 Ok(Arc::new(Subquery::new(self.exec_context_id, id, data_type)))
             }
+            ExprStruct::BloomFilterMightContain(expr) => {
+                let bloom_filter_expr =
+                    self.create_expr(expr.bloom_filter.as_ref().unwrap(), input_schema.clone())?;
+                let value_expr = self.create_expr(expr.value.as_ref().unwrap(), input_schema)?;
+                Ok(Arc::new(BloomFilterMightContain::try_new(
+                    bloom_filter_expr,
+                    value_expr,
+                )?))
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
diff --git a/core/src/execution/datafusion/spark_hash.rs b/core/src/execution/datafusion/spark_hash.rs
index aeefccf..1d8d1f2 100644
--- a/core/src/execution/datafusion/spark_hash.rs
+++ b/core/src/execution/datafusion/spark_hash.rs
@@ -32,7 +32,7 @@ use datafusion::{
 };
 
 #[inline]
-fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
+pub(crate) fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
     #[inline]
     fn mix_k1(mut k1: i32) -> i32 {
         k1 = k1.mul_wrapping(0xcc9e2d51u32 as i32);
diff --git a/core/src/execution/datafusion/mod.rs b/core/src/execution/datafusion/util/mod.rs
similarity index 85%
copy from core/src/execution/datafusion/mod.rs
copy to core/src/execution/datafusion/util/mod.rs
index f9fafeb..75b763a 100644
--- a/core/src/execution/datafusion/mod.rs
+++ b/core/src/execution/datafusion/util/mod.rs
@@ -15,10 +15,5 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Native execution through DataFusion
-
-mod expressions;
-mod operators;
-pub mod planner;
-pub(crate) mod shuffle_writer;
-mod spark_hash;
+pub mod spark_bit_array;
+pub mod spark_bloom_filter;
diff --git a/core/src/execution/datafusion/util/spark_bit_array.rs b/core/src/execution/datafusion/util/spark_bit_array.rs
new file mode 100644
index 0000000..9729627
--- /dev/null
+++ b/core/src/execution/datafusion/util/spark_bit_array.rs
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// A simple bit array implementation that simulates the behavior of Spark's BitArray, which is
+/// used in the BloomFilter implementation. Some methods are not implemented as they are not
+/// required for the current use case.
+#[derive(Debug, Hash)]
+pub struct SparkBitArray {
+    data: Vec<u64>,
+    bit_count: usize,
+}
+
+impl SparkBitArray {
+    pub fn new(buf: Vec<u64>) -> Self {
+        let num_bits = buf.iter().map(|x| x.count_ones() as usize).sum();
+        Self {
+            data: buf,
+            bit_count: num_bits,
+        }
+    }
+
+    pub fn set(&mut self, index: usize) -> bool {
+        if !self.get(index) {
+            // see the get method for the explanation of the shift operators
+            self.data[index >> 6] |= 1u64 << (index & 0x3f);
+            self.bit_count += 1;
+            true
+        } else {
+            false
+        }
+    }
+
+    pub fn get(&self, index: usize) -> bool {
+        // Java version: (data[(int) (index >> 6)] & (1L << (index))) != 0
+        // Rust and Java have different semantics for the shift operators. Java's shift operators
+        // explicitly mask the right-hand operand with 0x3f [1], while Rust's shift operators do
+        // not: they panic with "shift left with overflow" for a large right-hand operand.
+        // To fix this, we mask the right-hand operand with 0x3f on the Rust side.
+        // [1]: https://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.19
+        (self.data[index >> 6] & (1u64 << (index & 0x3f))) != 0
+    }
+
+    pub fn bit_size(&self) -> u64 {
+        self.data.len() as u64 * 64
+    }
+
+    pub fn cardinality(&self) -> usize {
+        self.bit_count
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_spark_bit_array() {
+        let buf = vec![0u64; 4];
+        let mut array = SparkBitArray::new(buf);
+        assert_eq!(array.bit_size(), 256);
+        assert_eq!(array.cardinality(), 0);
+
+        assert!(!array.get(0));
+        assert!(!array.get(1));
+        assert!(!array.get(63));
+        assert!(!array.get(64));
+        assert!(!array.get(65));
+        assert!(!array.get(127));
+        assert!(!array.get(128));
+        assert!(!array.get(129));
+
+        assert!(array.set(0));
+        assert!(array.set(1));
+        assert!(array.set(63));
+        assert!(array.set(64));
+        assert!(array.set(65));
+        assert!(array.set(127));
+        assert!(array.set(128));
+        assert!(array.set(129));
+
+        assert_eq!(array.cardinality(), 8);
+        assert_eq!(array.bit_size(), 256);
+
+        assert!(array.get(0));
+        // already set so should return false
+        assert!(!array.set(0));
+
+        // bits that were never set should return false for get
+        assert!(!array.get(2));
+        assert!(!array.get(62));
+    }
+
+    #[test]
+    fn test_spark_bit_with_non_empty_buffer() {
+        let buf = vec![8u64; 4];
+        let mut array = SparkBitArray::new(buf);
+        assert_eq!(array.bit_size(), 256);
+        assert_eq!(array.cardinality(), 4);
+
+        // already set bits should return true
+        assert!(array.get(3));
+        assert!(array.get(67));
+        assert!(array.get(131));
+        assert!(array.get(195));
+
+        // other unset bits should return false
+        assert!(!array.get(0));
+        assert!(!array.get(1));
+
+        // set bits
+        assert!(array.set(0));
+        assert!(array.set(1));
+
+        // check cardinality
+        assert_eq!(array.cardinality(), 6);
+    }
+}
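
To make the indexing in `get`/`set` concrete, a small sketch of the word/offset arithmetic (plain Rust, independent of the struct): `index >> 6` selects the u64 word and `index & 0x3f` keeps the shift amount below 64, which Java's `<<` does implicitly.

    let index: usize = 67;
    let word = index >> 6;   // 67 / 64 = 1 -> second u64 word
    let bit = index & 0x3f;  // 67 % 64 = 3 -> bit 3 within that word
    assert_eq!((word, bit), (1, 3));
    let mask = 1u64 << bit;  // always safe: bit < 64
    // `1u64 << 67` would overflow (and panics in debug builds), hence the mask.
    assert_eq!(mask, 0b1000);
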
diff --git a/core/src/execution/datafusion/util/spark_bloom_filter.rs b/core/src/execution/datafusion/util/spark_bloom_filter.rs
new file mode 100644
index 0000000..22957a1
--- /dev/null
+++ b/core/src/execution/datafusion/util/spark_bloom_filter.rs
@@ -0,0 +1,98 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::execution::datafusion::{
+    spark_hash::spark_compatible_murmur3_hash, util::spark_bit_array::SparkBitArray,
+};
+use arrow_array::{ArrowNativeTypeOp, BooleanArray, Int64Array};
+
+const SPARK_BLOOM_FILTER_VERSION_1: i32 = 1;
+
+/// A Bloom filter implementation that simulates the behavior of Spark's BloomFilter.
+/// It's not a complete implementation of Spark's BloomFilter; it only adds the minimum
+/// methods needed to support mightContainLong on the native side.
+#[derive(Debug, Hash)]
+pub struct SparkBloomFilter {
+    bits: SparkBitArray,
+    num_hash_functions: u32,
+}
+
+impl SparkBloomFilter {
+    pub fn new(buf: &[u8]) -> Self {
+        let mut offset = 0;
+        let version = read_num_be_bytes!(i32, 4, buf[offset..]);
+        offset += 4;
+        assert_eq!(
+            version, SPARK_BLOOM_FILTER_VERSION_1,
+            "Unsupported BloomFilter version: {}, expecting version: {}",
+            version, SPARK_BLOOM_FILTER_VERSION_1
+        );
+        let num_hash_functions = read_num_be_bytes!(i32, 4, buf[offset..]);
+        offset += 4;
+        let num_words = read_num_be_bytes!(i32, 4, buf[offset..]);
+        offset += 4;
+        let mut bits = vec![0u64; num_words as usize];
+        for i in 0..num_words {
+            bits[i as usize] = read_num_be_bytes!(i64, 8, buf[offset..]) as u64;
+            offset += 8;
+        }
+        Self {
+            bits: SparkBitArray::new(bits),
+            num_hash_functions: num_hash_functions as u32,
+        }
+    }
+
+    pub fn put_long(&mut self, item: i64) -> bool {
+        // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce
+        // n hash values by `h1 + i * h2` with 1 <= i <= num_hash_functions.
+        let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
+        let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
+        let bit_size = self.bits.bit_size() as i32;
+        let mut bit_changed = false;
+        for i in 1..=self.num_hash_functions {
+            let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
+            if combined_hash < 0 {
+                combined_hash = !combined_hash;
+            }
+            bit_changed |= self.bits.set((combined_hash % bit_size) as usize)
+        }
+        bit_changed
+    }
+
+    pub fn might_contain_long(&self, item: i64) -> bool {
+        let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
+        let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
+        let bit_size = self.bits.bit_size() as i32;
+        for i in 1..=self.num_hash_functions {
+            let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
+            if combined_hash < 0 {
+                combined_hash = !combined_hash;
+            }
+            if !self.bits.get((combined_hash % bit_size) as usize) {
+                return false;
+            }
+        }
+        true
+    }
+
+    pub fn might_contain_longs(&self, items: &Int64Array) -> BooleanArray {
+        items
+            .iter()
+            .map(|v| v.map(|x| self.might_contain_long(x)))
+            .collect()
+    }
+}
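
To tie the byte layout parsed by `new` to the double hashing in `put_long`/`might_contain_long`, a test-style sketch (hand-built buffer, not one produced by Spark):

    // Spark-format buffer: version 1, 3 hash functions, 2 words, all bits clear.
    let mut buf = Vec::new();
    buf.extend_from_slice(&1i32.to_be_bytes()); // SPARK_BLOOM_FILTER_VERSION_1
    buf.extend_from_slice(&3i32.to_be_bytes()); // num_hash_functions
    buf.extend_from_slice(&2i32.to_be_bytes()); // num_words
    buf.extend_from_slice(&[0u8; 16]);          // 2 x i64 words
    let mut filter = SparkBloomFilter::new(&buf);

    // put_long sets num_hash_functions bits derived from h1 + i * h2 (mod bit_size);
    // might_contain_long recomputes the same bits and requires all of them to be set.
    assert!(!filter.might_contain_long(42));
    filter.put_long(42);
    assert!(filter.might_contain_long(42));
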
diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index e8d35d1..58f607f 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -76,6 +76,7 @@ message Expr {
     Abs abs = 49;
     Subquery subquery = 50;
     UnboundReference unbound = 51;
+    BloomFilterMightContain bloom_filter_might_contain = 52;
   }
 }
 
@@ -432,6 +433,11 @@ message Subquery {
   DataType datatype = 2;
 }
 
+message BloomFilterMightContain {
+  Expr bloom_filter = 1;
+  Expr value = 2;
+}
+
 enum SortDirection {
   Ascending = 0;
   Descending = 1;
diff --git a/pom.xml b/pom.xml
index 7077aa3..bdc55ec 100644
--- a/pom.xml
+++ b/pom.xml
@@ -86,6 +86,7 @@ under the License.
       -Djdk.reflect.useDirectMethodHandle=false
     </extraJavaTestArgs>
     <argLine>-ea -Xmx4g -Xss4m ${extraJavaTestArgs}</argLine>
+    <additional.test.source>spark-3.3-plus</additional.test.source>
   </properties>
 
   <dependencyManagement>
@@ -494,6 +495,8 @@ under the License.
         <spark.version>3.2.2</spark.version>
         <spark.version.short>3.2</spark.version.short>
         <parquet.version>1.12.0</parquet.version>
+        <!-- we don't add special test suites for spark-3.2, so a non-existent dir is specified -->
+        <additional.test.source>not-needed-yet</additional.test.source>
       </properties>
     </profile>
 
@@ -504,6 +507,7 @@ under the License.
         <spark.version>3.3.2</spark.version>
         <spark.version.short>3.3</spark.version.short>
         <parquet.version>1.12.0</parquet.version>
+        <additional.test.source>spark-3.3-plus</additional.test.source>
       </properties>
     </profile>
 
@@ -513,6 +517,7 @@ under the License.
         <scala.version>2.12.17</scala.version>
         <spark.version.short>3.4</spark.version.short>
         <parquet.version>1.13.1</parquet.version>
+        <additional.test.source>spark-3.3-plus</additional.test.source>
       </properties>
     </profile>
 
@@ -777,6 +782,11 @@ under the License.
           <artifactId>jacoco-maven-plugin</artifactId>
           <version>${jacoco.version}</version>
         </plugin>
+        <plugin>
+          <groupId>org.codehaus.mojo</groupId>
+          <artifactId>build-helper-maven-plugin</artifactId>
+          <version>3.2.0</version>
+        </plugin>
       </plugins>
     </pluginManagement>
     <plugins>
diff --git a/spark/pom.xml b/spark/pom.xml
index 7e54fde..31d80bb 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -233,6 +233,24 @@ under the License.
         <groupId>net.alchim31.maven</groupId>
         <artifactId>scala-maven-plugin</artifactId>
       </plugin>
+      <plugin>
+        <groupId>org.codehaus.mojo</groupId>
+        <artifactId>build-helper-maven-plugin</artifactId>
+        <executions>
+          <execution>
+            <id>add-test-source</id>
+            <phase>generate-test-sources</phase>
+            <goals>
+              <goal>add-test-source</goal>
+            </goals>
+            <configuration>
+              <sources>
+                <source>src/test/${additional.test.source}</source>
+              </sources>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
     </plugins>
   </build>
 
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 87b4dff..abb4f0a 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1626,6 +1626,23 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
             "make_decimal",
             DecimalType(precision, scale),
             childExpr)
+        case b @ BinaryExpression(_, _) if isBloomFilterMightContain(b) =>
+          val bloomFilter = b.left
+          val value = b.right
+          val bloomFilterExpr = exprToProtoInternal(bloomFilter, inputs)
+          val valueExpr = exprToProtoInternal(value, inputs)
+          if (bloomFilterExpr.isDefined && valueExpr.isDefined) {
+            val builder = ExprOuterClass.BloomFilterMightContain.newBuilder()
+            builder.setBloomFilter(bloomFilterExpr.get)
+            builder.setValue(valueExpr.get)
+            Some(
+              ExprOuterClass.Expr
+                .newBuilder()
+                .setBloomFilterMightContain(builder)
+                .build())
+          } else {
+            None
+          }
 
         case e =>
           emitWarning(s"unsupported Spark expression: '$e' of class '${e.getClass.getName}")
diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
index c47b399..7bdf2c0 100644
--- a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
@@ -19,7 +19,7 @@
 
 package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic
+import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, BinaryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
 
 trait ShimQueryPlanSerde {
@@ -44,4 +44,9 @@ trait ShimQueryPlanSerde {
       failOnError.head
     }
   }
+
+  // TODO: delete after drop Spark 3.2 support
+  def isBloomFilterMightContain(binary: BinaryExpression): Boolean = {
+    binary.getClass.getName == "org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain"
+  }
 }
diff --git a/spark/src/test/spark-3.3-plus/org/apache/comet/CometExpression3_3PlusSuite.scala b/spark/src/test/spark-3.3-plus/org/apache/comet/CometExpression3_3PlusSuite.scala
new file mode 100644
index 0000000..6102777
--- /dev/null
+++ b/spark/src/test/spark-3.3-plus/org/apache/comet/CometExpression3_3PlusSuite.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import org.apache.spark.sql.{Column, CometTestBase}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, Expression, ExpressionInfo}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.util.sketch.BloomFilter
+
+import java.io.ByteArrayOutputStream
+import scala.util.Random
+
+class CometExpression3_3PlusSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+  import testImplicits._
+
+  val func_might_contain = new FunctionIdentifier("might_contain")
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Register 'might_contain' as a builtin function.
+    spark.sessionState.functionRegistry.registerFunction(func_might_contain,
+      new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
+      (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
+  }
+
+  override def afterAll(): Unit = {
+    spark.sessionState.functionRegistry.dropFunction(func_might_contain)
+    super.afterAll()
+  }
+
+  test("test BloomFilterMightContain can take a constant value input") {
+    val table = "test"
+
+    withTable(table) {
+      sql(s"create table $table(col1 long, col2 int) using parquet")
+      sql(s"insert into $table values (201, 1)")
+      checkSparkAnswerAndOperator(
+        s"""
+          |SELECT might_contain(
+          |X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267', col1) FROM $table
+          |""".stripMargin)
+    }
+  }
+
+  test("test NULL inputs for BloomFilterMightContain") {
+    val table = "test"
+
+    withTable(table) {
+      sql(s"create table $table(col1 long, col2 int) using parquet")
+      sql(s"insert into $table values (201, 1), (null, 2)")
+      checkSparkAnswerAndOperator(
+        s"""
+           |SELECT might_contain(null, null) both_null,
+           |       might_contain(null, 1L) null_bf,
+           |       might_contain(
+           |         X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267', col1) null_value
+           |       FROM $table
+           |""".stripMargin)
+    }
+  }
+
+  test("test BloomFilterMightContain from random input") {
+    val (longs, bfBytes) = bloomFilterFromRandomInput(10000, 10000)
+    val table = "test"
+
+    withTable(table) {
+      sql(s"create table $table(col1 long, col2 binary) using parquet")
+      spark.createDataset(longs).map(x => (x, bfBytes)).toDF("col1", "col2").write.insertInto(table)
+      val df = spark.table(table).select(new Column(BloomFilterMightContain(lit(bfBytes).expr, col("col1").expr)))
+      checkSparkAnswerAndOperator(df)
+      // check with scalar subquery
+      checkSparkAnswerAndOperator(
+        s"""
+          |SELECT might_contain((select first(col2) as col2 from $table), col1) FROM $table
+          |""".stripMargin)
+    }
+  }
+
+  private def bloomFilterFromRandomInput(expectedItems: Long, expectedBits: Long): (Seq[Long], Array[Byte]) = {
+    val bf = BloomFilter.create(expectedItems, expectedBits)
+    val longs = (0 until expectedItems.toInt).map(_ => Random.nextLong())
+    longs.foreach(bf.put)
+    val os = new ByteArrayOutputStream()
+    bf.writeTo(os)
+    (longs, os.toByteArray)
+  }
+}
