This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new e3ac6cf9 feat: Implement bloom_filter_agg (#987)
e3ac6cf9 is described below

commit e3ac6cf9b5f6e38f6a8f384e2f4e624ab75ed3b7
Author: Matt Butrovich <[email protected]>
AuthorDate: Fri Oct 18 16:43:36 2024 -0400

    feat: Implement bloom_filter_agg (#987)
    
    * Add test that invokes bloom_filter_agg.
    
    * QueryPlanSerde support for BloomFilterAgg.
    
    * Add bloom_filter_agg based on sample UDAF. Planner instantiates it now.
    Added spark_bit_array tests.
    
    * Partial work on Accumulator. Need to finish merge_batch and state.
    
    * BloomFilterAgg state, merge_state, and evaluate. Need more tests.
    
    * Matches Spark behavior. Need to clean up the code quite a bit, and run
    `cargo clippy`.
    
    * Remove old comment.
    
    * Clippy. Increase bloom filter size back to Spark's default.
    
    * API cleanup.
    
    * API cleanup.
    
    * Add BloomFilterAgg benchmark to CometExecBenchmark
    
    * Docs.
    
    * API cleanup, fix merge_bits to update cardinality.
    
    * Refactor merge_bits to update bit_count with the bit merging.
    
    * Remove benchmark results file.
    
    * Docs.
    
    * Add native side benchmarks.
    
    * Adjust benchmark parameters to match Spark defaults.
    
    * Address review feedback.
    
    * Add assertion to merge_batch.
    
    * Address some review feedback.
    
    * Only generate native BloomFilterAgg if child has LongType.
    
    * Add TODO with GitHub issue link.
---
 native/core/Cargo.toml                             |   4 +
 native/core/benches/bloom_filter_agg.rs            | 162 +++++++++++++++++++++
 .../datafusion/expressions/bloom_filter_agg.rs     | 151 +++++++++++++++++++
 .../expressions/bloom_filter_might_contain.rs      |   2 +-
 .../src/execution/datafusion/expressions/mod.rs    |   1 +
 native/core/src/execution/datafusion/planner.rs    |  17 +++
 .../execution/datafusion/util/spark_bit_array.rs   | 107 +++++++++++++-
 .../datafusion/util/spark_bloom_filter.rs          |  60 +++++++-
 native/proto/src/proto/expr.proto                  |   8 +
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  35 ++++-
 .../org/apache/comet/exec/CometExecSuite.scala     |  29 +++-
 .../spark/sql/benchmark/CometExecBenchmark.scala   |  81 +++++++++--
 12 files changed, 637 insertions(+), 20 deletions(-)

diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index 30dce574..daa0837c 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -126,3 +126,7 @@ harness = false
 [[bench]]
 name = "aggregate"
 harness = false
+
+[[bench]]
+name = "bloom_filter_agg"
+harness = false
diff --git a/native/core/benches/bloom_filter_agg.rs 
b/native/core/benches/bloom_filter_agg.rs
new file mode 100644
index 00000000..90e3e3f6
--- /dev/null
+++ b/native/core/benches/bloom_filter_agg.rs
@@ -0,0 +1,162 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::{DataType, Field, Schema};
+use arrow_array::builder::Int64Builder;
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_schema::SchemaRef;
+use 
comet::execution::datafusion::expressions::bloom_filter_agg::BloomFilterAgg;
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, 
PhysicalGroupBy};
+use datafusion::physical_plan::memory::MemoryExec;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion_common::ScalarValue;
+use datafusion_execution::TaskContext;
+use datafusion_expr::AggregateUDF;
+use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+use datafusion_physical_expr::expressions::{Column, Literal};
+use futures::StreamExt;
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::runtime::Runtime;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let mut group = c.benchmark_group("bloom_filter_agg");
+    let num_rows = 8192;
+    let batch = create_record_batch(num_rows);
+    let mut batches = Vec::new();
+    for _ in 0..10 {
+        batches.push(batch.clone());
+    }
+    let partitions = &[batches];
+    let c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
+    // Spark default for spark.sql.optimizer.runtime.bloomFilter.expectedNumItems
+    let num_items_sv = ScalarValue::Int64(Some(1000000_i64));
+    let num_items: Arc<dyn PhysicalExpr> = 
Arc::new(Literal::new(num_items_sv));
+    // Spark default for spark.sql.optimizer.runtime.bloomFilter.numBits
+    let num_bits_sv = ScalarValue::Int64(Some(8388608_i64));
+    let num_bits: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(num_bits_sv));
+
+    let rt = Runtime::new().unwrap();
+
+    for agg_mode in [
+        ("partial_agg", AggregateMode::Partial),
+        ("single_agg", AggregateMode::Single),
+    ] {
+        group.bench_function(agg_mode.0, |b| {
+            let comet_bloom_filter_agg =
+                Arc::new(AggregateUDF::new_from_impl(BloomFilterAgg::new(
+                    Arc::clone(&c0),
+                    Arc::clone(&num_items),
+                    Arc::clone(&num_bits),
+                    "bloom_filter_agg",
+                    DataType::Binary,
+                )));
+            b.to_async(&rt).iter(|| {
+                black_box(agg_test(
+                    partitions,
+                    c0.clone(),
+                    comet_bloom_filter_agg.clone(),
+                    "bloom_filter_agg",
+                    agg_mode.1,
+                ))
+            })
+        });
+    }
+
+    group.finish();
+}
+
+async fn agg_test(
+    partitions: &[Vec<RecordBatch>],
+    c0: Arc<dyn PhysicalExpr>,
+    aggregate_udf: Arc<AggregateUDF>,
+    alias: &str,
+    mode: AggregateMode,
+) {
+    let schema = &partitions[0][0].schema();
+    let scan: Arc<dyn ExecutionPlan> =
+        Arc::new(MemoryExec::try_new(partitions, Arc::clone(schema), 
None).unwrap());
+    let aggregate = create_aggregate(scan, c0.clone(), schema, aggregate_udf, 
alias, mode);
+    let mut stream = aggregate
+        .execute(0, Arc::new(TaskContext::default()))
+        .unwrap();
+    while let Some(batch) = stream.next().await {
+        let _batch = batch.unwrap();
+    }
+}
+
+fn create_aggregate(
+    scan: Arc<dyn ExecutionPlan>,
+    c0: Arc<dyn PhysicalExpr>,
+    schema: &SchemaRef,
+    aggregate_udf: Arc<AggregateUDF>,
+    alias: &str,
+    mode: AggregateMode,
+) -> Arc<AggregateExec> {
+    let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c0.clone()])
+        .schema(schema.clone())
+        .alias(alias)
+        .with_ignore_nulls(false)
+        .with_distinct(false)
+        .build()
+        .unwrap();
+
+    Arc::new(
+        AggregateExec::try_new(
+            mode,
+            PhysicalGroupBy::new_single(vec![]),
+            vec![aggr_expr],
+            vec![None],
+            scan,
+            Arc::clone(schema),
+        )
+        .unwrap(),
+    )
+}
+
+fn create_record_batch(num_rows: usize) -> RecordBatch {
+    let mut int64_builder = Int64Builder::with_capacity(num_rows);
+    for i in 0..num_rows {
+        int64_builder.append_value(i as i64);
+    }
+    let int64_array = Arc::new(int64_builder.finish());
+
+    let mut fields = vec![];
+    let mut columns: Vec<ArrayRef> = vec![];
+
+    // int64 column
+    fields.push(Field::new("c0", DataType::Int64, false));
+    columns.push(int64_array);
+
+    let schema = Schema::new(fields);
+    RecordBatch::try_new(Arc::new(schema), columns).unwrap()
+}
+
+fn config() -> Criterion {
+    Criterion::default()
+        .measurement_time(Duration::from_millis(500))
+        .warm_up_time(Duration::from_millis(500))
+}
+
+criterion_group! {
+    name = benches;
+    config = config();
+    targets = criterion_benchmark
+}
+criterion_main!(benches);
diff --git 
a/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs 
b/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
new file mode 100644
index 00000000..ed64b80e
--- /dev/null
+++ b/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_schema::Field;
+use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
+use std::{any::Any, sync::Arc};
+
+use crate::execution::datafusion::util::spark_bloom_filter;
+use crate::execution::datafusion::util::spark_bloom_filter::SparkBloomFilter;
+use arrow::array::ArrayRef;
+use arrow_array::BinaryArray;
+use datafusion::error::Result;
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion_common::{downcast_value, DataFusionError, ScalarValue};
+use datafusion_expr::{
+    function::{AccumulatorArgs, StateFieldsArgs},
+    Accumulator, AggregateUDFImpl, Signature,
+};
+use datafusion_physical_expr::expressions::Literal;
+
+#[derive(Debug, Clone)]
+pub struct BloomFilterAgg {
+    name: String,
+    signature: Signature,
+    expr: Arc<dyn PhysicalExpr>,
+    num_items: i32,
+    num_bits: i32,
+}
+
+#[inline]
+fn extract_i32_from_literal(expr: Arc<dyn PhysicalExpr>) -> i32 {
+    match expr.as_any().downcast_ref::<Literal>().unwrap().value() {
+        ScalarValue::Int64(scalar_value) => scalar_value.unwrap() as i32,
+        _ => {
+            unreachable!()
+        }
+    }
+}
+
+impl BloomFilterAgg {
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        num_items: Arc<dyn PhysicalExpr>,
+        num_bits: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+    ) -> Self {
+        assert!(matches!(data_type, DataType::Binary));
+        Self {
+            name: name.into(),
+            signature: Signature::exact(vec![DataType::Int64], 
Volatility::Immutable),
+            expr,
+            num_items: extract_i32_from_literal(num_items),
+            num_bits: extract_i32_from_literal(num_bits),
+        }
+    }
+}
+
+impl AggregateUDFImpl for BloomFilterAgg {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "bloom_filter_agg"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Binary)
+    }
+
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
+        Ok(Box::new(SparkBloomFilter::from((
+            spark_bloom_filter::optimal_num_hash_functions(self.num_items, 
self.num_bits),
+            self.num_bits,
+        ))))
+    }
+
+    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
+        Ok(vec![Field::new("bits", DataType::Binary, false)])
+    }
+
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
+        false
+    }
+}
+
+impl Accumulator for SparkBloomFilter {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+        let arr = &values[0];
+        (0..arr.len()).try_for_each(|index| {
+            let v = ScalarValue::try_from_array(arr, index)?;
+
+            if let ScalarValue::Int64(Some(value)) = v {
+                self.put_long(value);
+            } else {
+                unreachable!()
+            }
+            Ok(())
+        })
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Binary(Some(self.spark_serialization())))
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        // There might be a more efficient way to do this by transmuting since 
calling state() on an
+        // Accumulator is considered destructive.
+        let state_sv = ScalarValue::Binary(Some(self.state_as_bytes()));
+        Ok(vec![state_sv])
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        assert_eq!(
+            states.len(),
+            1,
+            "Expect one element in 'states' but found {}",
+            states.len()
+        );
+        assert_eq!(states[0].len(), 1);
+        let state_sv = downcast_value!(states[0], BinaryArray);
+        self.merge_filter(state_sv.value_data());
+        Ok(())
+    }
+}
diff --git 
a/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
 
b/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
index 462a2224..de922d83 100644
--- 
a/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
+++ 
b/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
@@ -72,7 +72,7 @@ fn evaluate_bloom_filter(
     let bloom_filter_bytes = bloom_filter_expr.evaluate(&batch)?;
     match bloom_filter_bytes {
         ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
-            Ok(v.map(|v| SparkBloomFilter::new(v.as_bytes())))
+            Ok(v.map(|v| SparkBloomFilter::from(v.as_bytes())))
         }
         _ => internal_err!("Bloom filter expression should be evaluated as a 
scalar binary value"),
     }
diff --git a/native/core/src/execution/datafusion/expressions/mod.rs 
b/native/core/src/execution/datafusion/expressions/mod.rs
index 10c9d309..48b80384 100644
--- a/native/core/src/execution/datafusion/expressions/mod.rs
+++ b/native/core/src/execution/datafusion/expressions/mod.rs
@@ -25,6 +25,7 @@ pub use normalize_nan::NormalizeNaNAndZero;
 use crate::errors::CometError;
 pub mod avg;
 pub mod avg_decimal;
+pub mod bloom_filter_agg;
 pub mod bloom_filter_might_contain;
 pub mod comet_scalar_funcs;
 pub mod correlation;
diff --git a/native/core/src/execution/datafusion/planner.rs 
b/native/core/src/execution/datafusion/planner.rs
index d63fd707..5b53cb39 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -28,6 +28,7 @@ use crate::{
                 avg::Avg,
                 avg_decimal::AvgDecimal,
                 bitwise_not::BitwiseNotExpr,
+                bloom_filter_agg::BloomFilterAgg,
                 bloom_filter_might_contain::BloomFilterMightContain,
                 checkoverflow::CheckOverflow,
                 correlation::Correlation,
@@ -1620,6 +1621,22 @@ impl PhysicalPlanner {
                 ));
                 Self::create_aggr_func_expr("correlation", schema, 
vec![child1, child2], func)
             }
+            AggExprStruct::BloomFilterAgg(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), 
Arc::clone(&schema))?;
+                let num_items =
+                    self.create_expr(expr.num_items.as_ref().unwrap(), 
Arc::clone(&schema))?;
+                let num_bits =
+                    self.create_expr(expr.num_bits.as_ref().unwrap(), 
Arc::clone(&schema))?;
+                let datatype = 
to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let func = AggregateUDF::new_from_impl(BloomFilterAgg::new(
+                    Arc::clone(&child),
+                    Arc::clone(&num_items),
+                    Arc::clone(&num_bits),
+                    "bloom_filter_agg",
+                    datatype,
+                ));
+                Self::create_aggr_func_expr("bloom_filter_agg", schema, 
vec![child], func)
+            }
         }
     }
 
diff --git a/native/core/src/execution/datafusion/util/spark_bit_array.rs 
b/native/core/src/execution/datafusion/util/spark_bit_array.rs
index 9729627d..68b97d66 100644
--- a/native/core/src/execution/datafusion/util/spark_bit_array.rs
+++ b/native/core/src/execution/datafusion/util/spark_bit_array.rs
@@ -15,6 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::common::bit;
+use arrow_buffer::ToByteSlice;
+use std::iter::zip;
+
 /// A simple bit array implementation that simulates the behavior of Spark's 
BitArray which is
 /// used in the BloomFilter implementation. Some methods are not implemented 
as they are not
 /// required for the current use case.
@@ -55,12 +59,50 @@ impl SparkBitArray {
     }
 
     pub fn bit_size(&self) -> u64 {
-        self.data.len() as u64 * 64
+        self.word_size() as u64 * 64
+    }
+
+    pub fn byte_size(&self) -> usize {
+        self.word_size() * 8
+    }
+
+    pub fn word_size(&self) -> usize {
+        self.data.len()
     }
 
     pub fn cardinality(&self) -> usize {
         self.bit_count
     }
+
+    pub fn to_bytes(&self) -> Vec<u8> {
+        Vec::from(self.data.to_byte_slice())
+    }
+
+    pub fn data(&self) -> Vec<u64> {
+        self.data.clone()
+    }
+
+    // Merges another SparkBitArray into this one by bitwise OR. `other` is a native-endian &[u8]
+    // (see to_bytes) rather than a word vector because it is expected to come from Arrow Binary.
+    pub fn merge_bits(&mut self, other: &[u8]) {
+        assert_eq!(self.byte_size(), other.len());
+        let mut bit_count: usize = 0;
+        // For each word, merge the bits into self, and accumulate a new 
bit_count.
+        for i in zip(
+            self.data.iter_mut(),
+            other
+                .chunks(8)
+                .map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())),
+        ) {
+            *i.0 |= i.1;
+            bit_count += i.0.count_ones() as usize;
+        }
+        self.bit_count = bit_count;
+    }
+}
+
+pub fn num_words(num_bits: i32) -> i32 {
+    bit::ceil(num_bits as usize, 64) as i32
 }
 
 #[cfg(test)]
@@ -128,4 +170,67 @@ mod test {
         // check cardinality
         assert_eq!(array.cardinality(), 6);
     }
+
+    #[test]
+    fn test_spark_bit_with_empty_buffer() {
+        let buf = vec![0u64; 4];
+        let array = SparkBitArray::new(buf);
+
+        assert_eq!(array.bit_size(), 256);
+        assert_eq!(array.cardinality(), 0);
+
+        for n in 0..256 {
+            assert!(!array.get(n));
+        }
+    }
+
+    #[test]
+    fn test_spark_bit_with_full_buffer() {
+        let buf = vec![u64::MAX; 4];
+        let array = SparkBitArray::new(buf);
+
+        assert_eq!(array.bit_size(), 256);
+        assert_eq!(array.cardinality(), 256);
+
+        for n in 0..256 {
+            assert!(array.get(n));
+        }
+    }
+
+    #[test]
+    fn test_spark_bit_merge() {
+        let buf1 = vec![0u64; 4];
+        let mut array1 = SparkBitArray::new(buf1);
+        let buf2 = vec![0u64; 4];
+        let mut array2 = SparkBitArray::new(buf2);
+
+        let primes = [
+            2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 
61, 67, 71, 73, 79, 83,
+            89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 
157, 163, 167, 173, 179,
+            181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251,
+        ];
+        let fibs = [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233];
+
+        for n in fibs {
+            array1.set(n);
+        }
+
+        for n in primes {
+            array2.set(n);
+        }
+
+        assert_eq!(array1.cardinality(), fibs.len());
+        assert_eq!(array2.cardinality(), primes.len());
+
+        array1.merge_bits(array2.to_bytes().as_slice());
+
+        for n in fibs {
+            assert!(array1.get(n));
+        }
+
+        for n in primes {
+            assert!(array1.get(n));
+        }
+        assert_eq!(array1.cardinality(), 60);
+    }
 }
diff --git a/native/core/src/execution/datafusion/util/spark_bloom_filter.rs 
b/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
index 00f71767..22a84d85 100644
--- a/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
+++ b/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
@@ -15,9 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::execution::datafusion::util::spark_bit_array;
 use crate::execution::datafusion::util::spark_bit_array::SparkBitArray;
 use arrow_array::{ArrowNativeTypeOp, BooleanArray, Int64Array};
+use arrow_buffer::ToByteSlice;
 use datafusion_comet_spark_expr::spark_hash::spark_compatible_murmur3_hash;
+use std::cmp;
 
 const SPARK_BLOOM_FILTER_VERSION_1: i32 = 1;
 
@@ -30,8 +33,29 @@ pub struct SparkBloomFilter {
     num_hash_functions: u32,
 }
 
-impl SparkBloomFilter {
-    pub fn new(buf: &[u8]) -> Self {
+pub fn optimal_num_hash_functions(expected_items: i32, num_bits: i32) -> i32 {
+    cmp::max(
+        1,
+        ((num_bits as f64 / expected_items as f64) * 2.0_f64.ln()).round() as 
i32,
+    )
+}
+
+impl From<(i32, i32)> for SparkBloomFilter {
+    /// Creates an empty SparkBloomFilter given number of hash functions and 
bits.
+    fn from((num_hash_functions, num_bits): (i32, i32)) -> Self {
+        let num_words = spark_bit_array::num_words(num_bits);
+        let bits = vec![0u64; num_words as usize];
+        Self {
+            bits: SparkBitArray::new(bits),
+            num_hash_functions: num_hash_functions as u32,
+        }
+    }
+}
+
+impl From<&[u8]> for SparkBloomFilter {
+    /// Creates a SparkBloomFilter from a serialized byte array conforming to 
Spark's BloomFilter
+    /// binary format version 1.
+    fn from(buf: &[u8]) -> Self {
         let mut offset = 0;
         let version = read_num_be_bytes!(i32, 4, buf[offset..]);
         offset += 4;
@@ -54,6 +78,25 @@ impl SparkBloomFilter {
             num_hash_functions: num_hash_functions as u32,
         }
     }
+}
+
+impl SparkBloomFilter {
+    /// Serializes a SparkBloomFilter to a byte array conforming to Spark's 
BloomFilter
+    /// binary format version 1.
+    pub fn spark_serialization(&self) -> Vec<u8> {
+        // There might be a more efficient way to do this, even with all the 
endianness stuff.
+        let mut spark_bloom_filter: Vec<u8> = 1_u32.to_be_bytes().to_vec();
+        spark_bloom_filter.append(&mut 
self.num_hash_functions.to_be_bytes().to_vec());
+        spark_bloom_filter.append(&mut (self.bits.word_size() as 
u32).to_be_bytes().to_vec());
+        let mut filter_state: Vec<u64> = self.bits.data();
+        for i in filter_state.iter_mut() {
+            *i = i.to_be();
+        }
+        // Vec::from(&[u8]) copies the bytes into a new allocation; filter_state (already a clone
+        // from data()) is then dropped. extend_from_slice would avoid the intermediate Vec.
+        spark_bloom_filter.append(&mut 
Vec::from(filter_state.to_byte_slice()));
+        spark_bloom_filter
+    }
 
     pub fn put_long(&mut self, item: i64) -> bool {
         // Here we first hash the input long element into 2 int hash values, 
h1 and h2, then produce
@@ -94,4 +137,17 @@ impl SparkBloomFilter {
             .map(|v| v.map(|x| self.might_contain_long(x)))
             .collect()
     }
+
+    pub fn state_as_bytes(&self) -> Vec<u8> {
+        self.bits.to_bytes()
+    }
+
+    pub fn merge_filter(&mut self, other: &[u8]) {
+        assert_eq!(
+            other.len(),
+            self.bits.byte_size(),
+            "Cannot merge SparkBloomFilters with different lengths."
+        );
+        self.bits.merge_bits(other);
+    }
 }
diff --git a/native/proto/src/proto/expr.proto 
b/native/proto/src/proto/expr.proto
index 1a3e3c9f..796ca5be 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -101,6 +101,7 @@ message AggExpr {
     Variance variance = 13;
     Stddev stddev = 14;
     Correlation correlation = 15;
+    BloomFilterAgg bloomFilterAgg = 16;
   }
 }
 
@@ -192,6 +193,13 @@ message Correlation {
   DataType datatype = 4;
 }
 
+message BloomFilterAgg {
+  Expr child = 1;
+  Expr numItems = 2;
+  Expr numBits = 3;
+  DataType datatype = 4;
+}
+
 message Literal {
   oneof value {
     bool bool_val = 1;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index c6e692cc..3805d418 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
-import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Average, BitAndAgg, BitOrAgg, BitXorAgg, Complete, Corr, Count, CovPopulation, 
CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, 
VariancePop, VarianceSamp}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, Complete, Corr, 
Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, 
StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, 
NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
@@ -760,6 +760,39 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde with CometExprShim
           withInfo(aggExpr, child1, child2)
           None
         }
+
+      case bloom_filter @ BloomFilterAggregate(child, numItems, numBits, _, _) 
=>
+        // We ignore mutableAggBufferOffset and inputAggBufferOffset because 
they are
+        // implementation details for Spark's ObjectHashAggregate.
+        val childExpr = exprToProto(child, inputs, binding)
+        val numItemsExpr = exprToProto(numItems, inputs, binding)
+        val numBitsExpr = exprToProto(numBits, inputs, binding)
+        val dataType = serializeDataType(bloom_filter.dataType)
+
+        // TODO: Support more types
+        //  https://github.com/apache/datafusion-comet/issues/1023
+        if (childExpr.isDefined &&
+          child.dataType
+            .isInstanceOf[LongType] &&
+          numItemsExpr.isDefined &&
+          numBitsExpr.isDefined &&
+          dataType.isDefined) {
+          val bloomFilterAggBuilder = 
ExprOuterClass.BloomFilterAgg.newBuilder()
+          bloomFilterAggBuilder.setChild(childExpr.get)
+          bloomFilterAggBuilder.setNumItems(numItemsExpr.get)
+          bloomFilterAggBuilder.setNumBits(numBitsExpr.get)
+          bloomFilterAggBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setBloomFilterAgg(bloomFilterAggBuilder)
+              .build())
+        } else {
+          withInfo(aggExpr, child, numItems, numBits)
+          None
+        }
+
       case fn =>
         val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
         emitWarning(msg)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 05aa2372..78f59cbe 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,10 +31,10 @@ import org.scalatest.Tag
 
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, 
DataFrame, DataFrameWriter, Row, SaveMode}
-import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, 
CatalogTable}
-import org.apache.spark.sql.catalyst.expressions.Hex
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, 
Hex}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, 
BloomFilterAggregate}
 import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, 
CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, 
CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometScanExec, 
CometSortExec, CometSortMergeJoinExec, CometSparkToColumnarExec, 
CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, 
CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, 
SQLExecution, UnionExec}
@@ -911,6 +911,29 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("bloom_filter_agg") {
+    val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+    spark.sessionState.functionRegistry.registerFunction(
+      funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, 
"bloom_filter_agg"),
+      (children: Seq[Expression]) =>
+        children.size match {
+          case 1 => new BloomFilterAggregate(children.head)
+          case 2 => new BloomFilterAggregate(children.head, children(1))
+          case 3 => new BloomFilterAggregate(children.head, children(1), 
children(2))
+        })
+
+    withParquetTable(
+      (0 until 100)
+        .map(_ => (Random.nextInt(), Random.nextInt() % 5)),
+      "tbl") {
+      val df = sql("SELECT bloom_filter_agg(cast(_2 as long)) FROM tbl")
+      checkSparkAnswerAndOperator(df)
+    }
+
+    spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+  }
+
   test("sort (non-global)") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
       val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc)
diff --git 
a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala 
b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
index 7205484e..3dd930f6 100644
--- 
a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
+++ 
b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
@@ -22,6 +22,9 @@ package org.apache.spark.sql.benchmark
 import org.apache.spark.SparkConf
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.comet.{CometConf, CometSparkSessionExtensions}
@@ -222,23 +225,77 @@ object CometExecBenchmark extends CometBenchmarkBase {
     }
   }
 
-  override def runCometBenchmark(mainArgs: Array[String]): Unit = {
-    runBenchmarkWithTable("Subquery", 1024 * 1024 * 10) { v =>
-      subqueryExecBenchmark(v)
-    }
+  // BloomFilterAgg takes an argument for the expected number of distinct 
values, which determines filter size and
+  // number of hash functions. We use the cardinality as a hint to the 
aggregate, otherwise the default Spark values
+  // make a big filter with a lot of hash functions.
+  def bloomFilterAggregate(values: Int, cardinality: Int): Unit = {
+    val benchmark =
+      new Benchmark(
+        s"BloomFilterAggregate Exec (cardinality $cardinality)",
+        values,
+        output = output)
 
-    runBenchmarkWithTable("Expand", 1024 * 1024 * 10) { v =>
-      expandExecBenchmark(v)
-    }
+    val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+    spark.sessionState.functionRegistry.registerFunction(
+      funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, 
"bloom_filter_agg"),
+      (children: Seq[Expression]) => new BloomFilterAggregate(children.head, 
children(1)))
+
+    withTempPath { dir =>
+      withTempTable("parquetV1Table") {
+        prepareTable(dir, spark.sql(s"SELECT floor(rand() * $cardinality) as 
key FROM $tbl"))
 
-    runBenchmarkWithTable("Project + Filter", 1024 * 1024 * 10) { v =>
-      for (fractionOfZeros <- List(0.0, 0.50, 0.95)) {
-        numericFilterExecBenchmark(v, fractionOfZeros)
+        val query =
+          s"SELECT bloom_filter_agg(cast(key as long), cast($cardinality as 
long)) FROM parquetV1Table"
+
+        benchmark.addCase("SQL Parquet - Spark (BloomFilterAgg)") { _ =>
+          spark.sql(query).noop()
+        }
+
+        benchmark.addCase("SQL Parquet - Comet (Scan) (BloomFilterAgg)") { _ =>
+          withSQLConf(CometConf.COMET_ENABLED.key -> "true") {
+            spark.sql(query).noop()
+          }
+        }
+
+        benchmark.addCase("SQL Parquet - Comet (Scan, Exec) (BloomFilterAgg)") 
{ _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true") {
+            spark.sql(query).noop()
+          }
+        }
+
+        benchmark.run()
       }
     }
 
-    runBenchmarkWithTable("Sort", 1024 * 1024 * 10) { v =>
-      sortExecBenchmark(v)
+    spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+  }
+
+  override def runCometBenchmark(mainArgs: Array[String]): Unit = {
+//    runBenchmarkWithTable("Subquery", 1024 * 1024 * 10) { v =>
+//      subqueryExecBenchmark(v)
+//    }
+//
+//    runBenchmarkWithTable("Expand", 1024 * 1024 * 10) { v =>
+//      expandExecBenchmark(v)
+//    }
+//
+//    runBenchmarkWithTable("Project + Filter", 1024 * 1024 * 10) { v =>
+//      for (fractionOfZeros <- List(0.0, 0.50, 0.95)) {
+//        numericFilterExecBenchmark(v, fractionOfZeros)
+//      }
+//    }
+//
+//    runBenchmarkWithTable("Sort", 1024 * 1024 * 10) { v =>
+//      sortExecBenchmark(v)
+//    }
+
+    runBenchmarkWithTable("BloomFilterAggregate", 1024 * 1024 * 10) { v =>
+      for (card <- List(100, 1024, 1024 * 1024)) {
+        bloomFilterAggregate(v, card)
+      }
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to