This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new e3ac6cf9 feat: Implement bloom_filter_agg (#987)
e3ac6cf9 is described below
commit e3ac6cf9b5f6e38f6a8f384e2f4e624ab75ed3b7
Author: Matt Butrovich <[email protected]>
AuthorDate: Fri Oct 18 16:43:36 2024 -0400
feat: Implement bloom_filter_agg (#987)
* Add test that invokes bloom_filter_agg.
* QueryPlanSerde support for BloomFilterAgg.
* Add bloom_filter_agg based on sample UDAF. The planner instantiates it now.
Added spark_bit_array_tests.
* Partial work on Accumulator. Need to finish merge_batch and state.
* BloomFilterAgg state, merge_state, and evaluate. Need more tests.
* Matches Spark behavior. Need to clean up the code quite a bit, and run `cargo clippy`.
* Remove old comment.
* Clippy. Increase bloom filter size back to Spark's default.
* API cleanup.
* API cleanup.
* Add BloomFilterAgg benchmark to CometExecBenchmark
* Docs.
* API cleanup, fix merge_bits to update cardinality.
* Refactor merge_bits to update bit_count with the bit merging.
* Remove benchmark results file.
* Docs.
* Add native side benchmarks.
* Adjust benchmark parameters to match Spark defaults.
* Address review feedback.
* Add assertion to merge_batch.
* Address some review feedback.
* Only generate native BloomFilterAgg if child has LongType.
* Add TODO with GitHub issue link.
---
native/core/Cargo.toml | 4 +
native/core/benches/bloom_filter_agg.rs | 162 +++++++++++++++++++++
.../datafusion/expressions/bloom_filter_agg.rs | 151 +++++++++++++++++++
.../expressions/bloom_filter_might_contain.rs | 2 +-
.../src/execution/datafusion/expressions/mod.rs | 1 +
native/core/src/execution/datafusion/planner.rs | 17 +++
.../execution/datafusion/util/spark_bit_array.rs | 107 +++++++++++++-
.../datafusion/util/spark_bloom_filter.rs | 60 +++++++-
native/proto/src/proto/expr.proto | 8 +
.../org/apache/comet/serde/QueryPlanSerde.scala | 35 ++++-
.../org/apache/comet/exec/CometExecSuite.scala | 29 +++-
.../spark/sql/benchmark/CometExecBenchmark.scala | 81 +++++++++--
12 files changed, 637 insertions(+), 20 deletions(-)
diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index 30dce574..daa0837c 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -126,3 +126,7 @@ harness = false
[[bench]]
name = "aggregate"
harness = false
+
+[[bench]]
+name = "bloom_filter_agg"
+harness = false
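
With this [[bench]] target registered, the native benchmark can be run on its own via Criterion's standard entry point, e.g. `cargo bench --bench bloom_filter_agg` from the native/core directory.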
diff --git a/native/core/benches/bloom_filter_agg.rs b/native/core/benches/bloom_filter_agg.rs
new file mode 100644
index 00000000..90e3e3f6
--- /dev/null
+++ b/native/core/benches/bloom_filter_agg.rs
@@ -0,0 +1,162 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::{DataType, Field, Schema};
+use arrow_array::builder::Int64Builder;
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_schema::SchemaRef;
+use comet::execution::datafusion::expressions::bloom_filter_agg::BloomFilterAgg;
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
+use datafusion::physical_plan::memory::MemoryExec;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion_common::ScalarValue;
+use datafusion_execution::TaskContext;
+use datafusion_expr::AggregateUDF;
+use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+use datafusion_physical_expr::expressions::{Column, Literal};
+use futures::StreamExt;
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::runtime::Runtime;
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let mut group = c.benchmark_group("bloom_filter_agg");
+ let num_rows = 8192;
+ let batch = create_record_batch(num_rows);
+ let mut batches = Vec::new();
+ for _ in 0..10 {
+ batches.push(batch.clone());
+ }
+ let partitions = &[batches];
+ let c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
+ // spark.sql.optimizer.runtime.bloomFilter.expectedNumItems
+ let num_items_sv = ScalarValue::Int64(Some(1000000_i64));
+ let num_items: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(num_items_sv));
+ // spark.sql.optimizer.runtime.bloomFilter.numBits
+ let num_bits_sv = ScalarValue::Int64(Some(8388608_i64));
+ let num_bits: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(num_bits_sv));
+
+ let rt = Runtime::new().unwrap();
+
+ for agg_mode in [
+ ("partial_agg", AggregateMode::Partial),
+ ("single_agg", AggregateMode::Single),
+ ] {
+ group.bench_function(agg_mode.0, |b| {
+ let comet_bloom_filter_agg =
+ Arc::new(AggregateUDF::new_from_impl(BloomFilterAgg::new(
+ Arc::clone(&c0),
+ Arc::clone(&num_items),
+ Arc::clone(&num_bits),
+ "bloom_filter_agg",
+ DataType::Binary,
+ )));
+ b.to_async(&rt).iter(|| {
+ black_box(agg_test(
+ partitions,
+ c0.clone(),
+ comet_bloom_filter_agg.clone(),
+ "bloom_filter_agg",
+ agg_mode.1,
+ ))
+ })
+ });
+ }
+
+ group.finish();
+}
+
+async fn agg_test(
+ partitions: &[Vec<RecordBatch>],
+ c0: Arc<dyn PhysicalExpr>,
+ aggregate_udf: Arc<AggregateUDF>,
+ alias: &str,
+ mode: AggregateMode,
+) {
+ let schema = &partitions[0][0].schema();
+ let scan: Arc<dyn ExecutionPlan> =
+ Arc::new(MemoryExec::try_new(partitions, Arc::clone(schema), None).unwrap());
+ let aggregate = create_aggregate(scan, c0.clone(), schema, aggregate_udf, alias, mode);
+ let mut stream = aggregate
+ .execute(0, Arc::new(TaskContext::default()))
+ .unwrap();
+ while let Some(batch) = stream.next().await {
+ let _batch = batch.unwrap();
+ }
+}
+
+fn create_aggregate(
+ scan: Arc<dyn ExecutionPlan>,
+ c0: Arc<dyn PhysicalExpr>,
+ schema: &SchemaRef,
+ aggregate_udf: Arc<AggregateUDF>,
+ alias: &str,
+ mode: AggregateMode,
+) -> Arc<AggregateExec> {
+ let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c0.clone()])
+ .schema(schema.clone())
+ .alias(alias)
+ .with_ignore_nulls(false)
+ .with_distinct(false)
+ .build()
+ .unwrap();
+
+ Arc::new(
+ AggregateExec::try_new(
+ mode,
+ PhysicalGroupBy::new_single(vec![]),
+ vec![aggr_expr],
+ vec![None],
+ scan,
+ Arc::clone(schema),
+ )
+ .unwrap(),
+ )
+}
+
+fn create_record_batch(num_rows: usize) -> RecordBatch {
+ let mut int64_builder = Int64Builder::with_capacity(num_rows);
+ for i in 0..num_rows {
+ int64_builder.append_value(i as i64);
+ }
+ let int64_array = Arc::new(int64_builder.finish());
+
+ let mut fields = vec![];
+ let mut columns: Vec<ArrayRef> = vec![];
+
+ // int64 column
+ fields.push(Field::new("c0", DataType::Int64, false));
+ columns.push(int64_array);
+
+ let schema = Schema::new(fields);
+ RecordBatch::try_new(Arc::new(schema), columns).unwrap()
+}
+
+fn config() -> Criterion {
+ Criterion::default()
+ .measurement_time(Duration::from_millis(500))
+ .warm_up_time(Duration::from_millis(500))
+}
+
+criterion_group! {
+ name = benches;
+ config = config();
+ targets = criterion_benchmark
+}
+criterion_main!(benches);
diff --git a/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs b/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
new file mode 100644
index 00000000..ed64b80e
--- /dev/null
+++ b/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_schema::Field;
+use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
+use std::{any::Any, sync::Arc};
+
+use crate::execution::datafusion::util::spark_bloom_filter;
+use crate::execution::datafusion::util::spark_bloom_filter::SparkBloomFilter;
+use arrow::array::ArrayRef;
+use arrow_array::BinaryArray;
+use datafusion::error::Result;
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion_common::{downcast_value, DataFusionError, ScalarValue};
+use datafusion_expr::{
+ function::{AccumulatorArgs, StateFieldsArgs},
+ Accumulator, AggregateUDFImpl, Signature,
+};
+use datafusion_physical_expr::expressions::Literal;
+
+#[derive(Debug, Clone)]
+pub struct BloomFilterAgg {
+ name: String,
+ signature: Signature,
+ expr: Arc<dyn PhysicalExpr>,
+ num_items: i32,
+ num_bits: i32,
+}
+
+#[inline]
+fn extract_i32_from_literal(expr: Arc<dyn PhysicalExpr>) -> i32 {
+ match expr.as_any().downcast_ref::<Literal>().unwrap().value() {
+ ScalarValue::Int64(scalar_value) => scalar_value.unwrap() as i32,
+ _ => {
+ unreachable!()
+ }
+ }
+}
+
+impl BloomFilterAgg {
+ pub fn new(
+ expr: Arc<dyn PhysicalExpr>,
+ num_items: Arc<dyn PhysicalExpr>,
+ num_bits: Arc<dyn PhysicalExpr>,
+ name: impl Into<String>,
+ data_type: DataType,
+ ) -> Self {
+ assert!(matches!(data_type, DataType::Binary));
+ Self {
+ name: name.into(),
signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable),
+ expr,
+ num_items: extract_i32_from_literal(num_items),
+ num_bits: extract_i32_from_literal(num_bits),
+ }
+ }
+}
+
+impl AggregateUDFImpl for BloomFilterAgg {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "bloom_filter_agg"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ Ok(DataType::Binary)
+ }
+
+ fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(SparkBloomFilter::from((
+ spark_bloom_filter::optimal_num_hash_functions(self.num_items, self.num_bits),
+ self.num_bits,
+ ))))
+ }
+
+ fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
+ Ok(vec![Field::new("bits", DataType::Binary, false)])
+ }
+
+ fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
+ false
+ }
+}
+
+impl Accumulator for SparkBloomFilter {
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ if values.is_empty() {
+ return Ok(());
+ }
+ let arr = &values[0];
+ (0..arr.len()).try_for_each(|index| {
+ let v = ScalarValue::try_from_array(arr, index)?;
+
+ if let ScalarValue::Int64(Some(value)) = v {
+ self.put_long(value);
+ } else {
+ unreachable!()
+ }
+ Ok(())
+ })
+ }
+
+ fn evaluate(&mut self) -> Result<ScalarValue> {
+ Ok(ScalarValue::Binary(Some(self.spark_serialization())))
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
+ // There might be a more efficient way to do this by transmuting, since calling state() on an
+ // Accumulator is considered destructive.
+ let state_sv = ScalarValue::Binary(Some(self.state_as_bytes()));
+ Ok(vec![state_sv])
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ assert_eq!(
+ states.len(),
+ 1,
+ "Expect one element in 'states' but found {}",
+ states.len()
+ );
+ assert_eq!(states[0].len(), 1);
+ let state_sv = downcast_value!(states[0], BinaryArray);
+ self.merge_filter(state_sv.value_data());
+ Ok(())
+ }
+}
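
To make the Accumulator round-trip above concrete: a partial accumulator's state() emits the raw bit-array bytes as a single Binary scalar, merge_batch() ORs those bytes back into a downstream accumulator, and evaluate() produces the Spark-format filter. A minimal sketch of that flow (not part of this patch; it assumes the util module is publicly reachable from the comet crate):

    use std::sync::Arc;
    use arrow_array::{ArrayRef, BinaryArray, Int64Array};
    use comet::execution::datafusion::util::spark_bloom_filter::SparkBloomFilter;
    use datafusion_common::ScalarValue;
    use datafusion_expr::Accumulator;

    fn main() -> datafusion::error::Result<()> {
        // (num_hash_functions, num_bits); one accumulator per "partition".
        let mut partial = SparkBloomFilter::from((6, 64));
        let mut final_acc = SparkBloomFilter::from((6, 64));

        let input: ArrayRef = Arc::new(Int64Array::from(vec![1i64, 2, 3]));
        partial.update_batch(&[input])?;

        // state() yields one Binary scalar holding the raw bit words.
        if let ScalarValue::Binary(Some(bytes)) = &partial.state()?[0] {
            let state_col: ArrayRef = Arc::new(BinaryArray::from(vec![bytes.as_slice()]));
            final_acc.merge_batch(&[state_col])?; // ORs the bits in
        }

        // evaluate() produces Spark's version-1 binary format.
        let _serialized = final_acc.evaluate()?;
        Ok(())
    }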
diff --git a/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs b/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
index 462a2224..de922d83 100644
--- a/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
+++ b/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
@@ -72,7 +72,7 @@ fn evaluate_bloom_filter(
let bloom_filter_bytes = bloom_filter_expr.evaluate(&batch)?;
match bloom_filter_bytes {
ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
- Ok(v.map(|v| SparkBloomFilter::new(v.as_bytes())))
+ Ok(v.map(|v| SparkBloomFilter::from(v.as_bytes())))
}
_ => internal_err!("Bloom filter expression should be evaluated as a scalar binary value"),
}
diff --git a/native/core/src/execution/datafusion/expressions/mod.rs b/native/core/src/execution/datafusion/expressions/mod.rs
index 10c9d309..48b80384 100644
--- a/native/core/src/execution/datafusion/expressions/mod.rs
+++ b/native/core/src/execution/datafusion/expressions/mod.rs
@@ -25,6 +25,7 @@ pub use normalize_nan::NormalizeNaNAndZero;
use crate::errors::CometError;
pub mod avg;
pub mod avg_decimal;
+pub mod bloom_filter_agg;
pub mod bloom_filter_might_contain;
pub mod comet_scalar_funcs;
pub mod correlation;
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index d63fd707..5b53cb39 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -28,6 +28,7 @@ use crate::{
avg::Avg,
avg_decimal::AvgDecimal,
bitwise_not::BitwiseNotExpr,
+ bloom_filter_agg::BloomFilterAgg,
bloom_filter_might_contain::BloomFilterMightContain,
checkoverflow::CheckOverflow,
correlation::Correlation,
@@ -1620,6 +1621,22 @@ impl PhysicalPlanner {
));
Self::create_aggr_func_expr("correlation", schema,
vec![child1, child2], func)
}
+ AggExprStruct::BloomFilterAgg(expr) => {
+ let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
+ let num_items =
+ self.create_expr(expr.num_items.as_ref().unwrap(), Arc::clone(&schema))?;
+ let num_bits =
+ self.create_expr(expr.num_bits.as_ref().unwrap(), Arc::clone(&schema))?;
+ let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+ let func = AggregateUDF::new_from_impl(BloomFilterAgg::new(
+ Arc::clone(&child),
+ Arc::clone(&num_items),
+ Arc::clone(&num_bits),
+ "bloom_filter_agg",
+ datatype,
+ ));
+ Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
+ }
}
}
diff --git a/native/core/src/execution/datafusion/util/spark_bit_array.rs b/native/core/src/execution/datafusion/util/spark_bit_array.rs
index 9729627d..68b97d66 100644
--- a/native/core/src/execution/datafusion/util/spark_bit_array.rs
+++ b/native/core/src/execution/datafusion/util/spark_bit_array.rs
@@ -15,6 +15,10 @@
// specific language governing permissions and limitations
// under the License.
+use crate::common::bit;
+use arrow_buffer::ToByteSlice;
+use std::iter::zip;
+
/// A simple bit array implementation that simulates the behavior of Spark's BitArray which is
/// used in the BloomFilter implementation. Some methods are not implemented as they are not
/// required for the current use case.
@@ -55,12 +59,50 @@ impl SparkBitArray {
}
pub fn bit_size(&self) -> u64 {
- self.data.len() as u64 * 64
+ self.word_size() as u64 * 64
+ }
+
+ pub fn byte_size(&self) -> usize {
+ self.word_size() * 8
+ }
+
+ pub fn word_size(&self) -> usize {
+ self.data.len()
}
pub fn cardinality(&self) -> usize {
self.bit_count
}
+
+ pub fn to_bytes(&self) -> Vec<u8> {
+ Vec::from(self.data.to_byte_slice())
+ }
+
+ pub fn data(&self) -> Vec<u64> {
+ self.data.clone()
+ }
+
+ // Combines SparkBitArrays; `other` is a &[u8] because we anticipate it coming from an
+ // Arrow ScalarValue::Binary, which is a byte vector underneath rather than a word vector.
+ pub fn merge_bits(&mut self, other: &[u8]) {
+ assert_eq!(self.byte_size(), other.len());
+ let mut bit_count: usize = 0;
+ // For each word, merge the bits into self, and accumulate a new bit_count.
+ for i in zip(
+ self.data.iter_mut(),
+ other
+ .chunks(8)
+ .map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())),
+ ) {
+ *i.0 |= i.1;
+ bit_count += i.0.count_ones() as usize;
+ }
+ self.bit_count = bit_count;
+ }
+}
+
+pub fn num_words(num_bits: i32) -> i32 {
+ bit::ceil(num_bits as usize, 64) as i32
}
#[cfg(test)]
@@ -128,4 +170,67 @@ mod test {
// check cardinality
assert_eq!(array.cardinality(), 6);
}
+
+ #[test]
+ fn test_spark_bit_with_empty_buffer() {
+ let buf = vec![0u64; 4];
+ let array = SparkBitArray::new(buf);
+
+ assert_eq!(array.bit_size(), 256);
+ assert_eq!(array.cardinality(), 0);
+
+ for n in 0..256 {
+ assert!(!array.get(n));
+ }
+ }
+
+ #[test]
+ fn test_spark_bit_with_full_buffer() {
+ let buf = vec![u64::MAX; 4];
+ let array = SparkBitArray::new(buf);
+
+ assert_eq!(array.bit_size(), 256);
+ assert_eq!(array.cardinality(), 256);
+
+ for n in 0..256 {
+ assert!(array.get(n));
+ }
+ }
+
+ #[test]
+ fn test_spark_bit_merge() {
+ let buf1 = vec![0u64; 4];
+ let mut array1 = SparkBitArray::new(buf1);
+ let buf2 = vec![0u64; 4];
+ let mut array2 = SparkBitArray::new(buf2);
+
+ let primes = [
+ 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83,
+ 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179,
+ 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251,
+ ];
+ let fibs = [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233];
+
+ for n in fibs {
+ array1.set(n);
+ }
+
+ for n in primes {
+ array2.set(n);
+ }
+
+ assert_eq!(array1.cardinality(), fibs.len());
+ assert_eq!(array2.cardinality(), primes.len());
+
+ array1.merge_bits(array2.to_bytes().as_slice());
+
+ for n in fibs {
+ assert!(array1.get(n));
+ }
+
+ for n in primes {
+ assert!(array1.get(n));
+ }
+ assert_eq!(array1.cardinality(), 60);
+ }
}
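
Note that merge_bits recomputes bit_count from scratch instead of summing the two cardinalities, because the inputs may share set bits. A standalone illustration of that detail, using plain u64 words:

    // Both inputs have bit 0 set, so the merged cardinality is 3, not 2 + 2.
    fn main() {
        let mut left: Vec<u64> = vec![0b0011]; // bits 0 and 1
        let right: Vec<u64> = vec![0b0101]; // bits 0 and 2
        let mut bit_count = 0usize;
        for (l, r) in left.iter_mut().zip(right.iter()) {
            *l |= *r;
            bit_count += l.count_ones() as usize;
        }
        assert_eq!(bit_count, 3);
    }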
diff --git a/native/core/src/execution/datafusion/util/spark_bloom_filter.rs b/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
index 00f71767..22a84d85 100644
--- a/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
+++ b/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
@@ -15,9 +15,12 @@
// specific language governing permissions and limitations
// under the License.
+use crate::execution::datafusion::util::spark_bit_array;
use crate::execution::datafusion::util::spark_bit_array::SparkBitArray;
use arrow_array::{ArrowNativeTypeOp, BooleanArray, Int64Array};
+use arrow_buffer::ToByteSlice;
use datafusion_comet_spark_expr::spark_hash::spark_compatible_murmur3_hash;
+use std::cmp;
const SPARK_BLOOM_FILTER_VERSION_1: i32 = 1;
@@ -30,8 +33,29 @@ pub struct SparkBloomFilter {
num_hash_functions: u32,
}
-impl SparkBloomFilter {
- pub fn new(buf: &[u8]) -> Self {
+pub fn optimal_num_hash_functions(expected_items: i32, num_bits: i32) -> i32 {
+ cmp::max(
+ 1,
+ ((num_bits as f64 / expected_items as f64) * 2.0_f64.ln()).round() as i32,
+ )
+}
+
+impl From<(i32, i32)> for SparkBloomFilter {
+ /// Creates an empty SparkBloomFilter given a number of hash functions and bits.
+ fn from((num_hash_functions, num_bits): (i32, i32)) -> Self {
+ let num_words = spark_bit_array::num_words(num_bits);
+ let bits = vec![0u64; num_words as usize];
+ Self {
+ bits: SparkBitArray::new(bits),
+ num_hash_functions: num_hash_functions as u32,
+ }
+ }
+}
+
+impl From<&[u8]> for SparkBloomFilter {
+ /// Creates a SparkBloomFilter from a serialized byte array conforming to Spark's BloomFilter
+ /// binary format version 1.
+ fn from(buf: &[u8]) -> Self {
let mut offset = 0;
let version = read_num_be_bytes!(i32, 4, buf[offset..]);
offset += 4;
@@ -54,6 +78,25 @@ impl SparkBloomFilter {
num_hash_functions: num_hash_functions as u32,
}
}
+}
+
+impl SparkBloomFilter {
+ /// Serializes a SparkBloomFilter to a byte array conforming to Spark's BloomFilter
+ /// binary format version 1.
+ pub fn spark_serialization(&self) -> Vec<u8> {
+ // There might be a more efficient way to do this, even with all the endianness stuff.
+ let mut spark_bloom_filter: Vec<u8> = 1_u32.to_be_bytes().to_vec();
+ spark_bloom_filter.append(&mut self.num_hash_functions.to_be_bytes().to_vec());
+ spark_bloom_filter.append(&mut (self.bits.word_size() as u32).to_be_bytes().to_vec());
+ let mut filter_state: Vec<u64> = self.bits.data();
+ for i in filter_state.iter_mut() {
+ *i = i.to_be();
+ }
+ // Does it make sense to do a std::mem::take of filter_state here? Unclear to me if a deep
+ // copy of filter_state as a Vec<u64> to a Vec<u8> is happening here.
+ spark_bloom_filter.append(&mut Vec::from(filter_state.to_byte_slice()));
+ spark_bloom_filter
+ }
pub fn put_long(&mut self, item: i64) -> bool {
// Here we first hash the input long element into 2 int hash values, h1 and h2, then produce
@@ -94,4 +137,17 @@ impl SparkBloomFilter {
.map(|v| v.map(|x| self.might_contain_long(x)))
.collect()
}
+
+ pub fn state_as_bytes(&self) -> Vec<u8> {
+ self.bits.to_bytes()
+ }
+
+ pub fn merge_filter(&mut self, other: &[u8]) {
+ assert_eq!(
+ other.len(),
+ self.bits.byte_size(),
+ "Cannot merge SparkBloomFilters with different lengths."
+ );
+ self.bits.merge_bits(other);
+ }
}
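
As a worked example of the sizing math above: with the Spark defaults used by the benchmark in this commit (expectedNumItems = 1,000,000, numBits = 8,388,608), optimal_num_hash_functions gives round(8.388608 * ln 2) = 6, and the version-1 layout (i32 version, i32 hash-function count, i32 word count, then big-endian words) serializes to 12 + 8 * 131072 = 1,048,588 bytes:

    fn main() {
        let expected_items = 1_000_000_f64; // expectedNumItems default
        let num_bits = 8_388_608_f64; // numBits default
        // Same formula as optimal_num_hash_functions.
        let k = ((num_bits / expected_items) * 2.0_f64.ln()).round().max(1.0) as i32;
        assert_eq!(k, 6);

        // Header is three big-endian i32s, then the 64-bit words.
        let num_words = (num_bits as usize + 63) / 64;
        assert_eq!(num_words, 131072);
        assert_eq!(4 + 4 + 4 + 8 * num_words, 1_048_588);
    }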
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 1a3e3c9f..796ca5be 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -101,6 +101,7 @@ message AggExpr {
Variance variance = 13;
Stddev stddev = 14;
Correlation correlation = 15;
+ BloomFilterAgg bloomFilterAgg = 16;
}
}
@@ -192,6 +193,13 @@ message Correlation {
DataType datatype = 4;
}
+message BloomFilterAgg {
+ Expr child = 1;
+ Expr numItems = 2;
+ Expr numBits = 3;
+ DataType datatype = 4;
+}
+
message Literal {
oneof value {
bool bool_val = 1;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index c6e692cc..3805d418 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Complete, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, Complete, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
@@ -760,6 +760,39 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
withInfo(aggExpr, child1, child2)
None
}
+
+ case bloom_filter @ BloomFilterAggregate(child, numItems, numBits, _, _) =>
+ // We ignore mutableAggBufferOffset and inputAggBufferOffset because they are
+ // implementation details for Spark's ObjectHashAggregate.
+ val childExpr = exprToProto(child, inputs, binding)
+ val numItemsExpr = exprToProto(numItems, inputs, binding)
+ val numBitsExpr = exprToProto(numBits, inputs, binding)
+ val dataType = serializeDataType(bloom_filter.dataType)
+
+ // TODO: Support more types
+ // https://github.com/apache/datafusion-comet/issues/1023
+ if (childExpr.isDefined &&
+ child.dataType
+ .isInstanceOf[LongType] &&
+ numItemsExpr.isDefined &&
+ numBitsExpr.isDefined &&
+ dataType.isDefined) {
+ val bloomFilterAggBuilder = ExprOuterClass.BloomFilterAgg.newBuilder()
+ bloomFilterAggBuilder.setChild(childExpr.get)
+ bloomFilterAggBuilder.setNumItems(numItemsExpr.get)
+ bloomFilterAggBuilder.setNumBits(numBitsExpr.get)
+ bloomFilterAggBuilder.setDatatype(dataType.get)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setBloomFilterAgg(bloomFilterAggBuilder)
+ .build())
+ } else {
+ withInfo(aggExpr, child, numItems, numBits)
+ None
+ }
+
case fn =>
val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
emitWarning(msg)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 05aa2372..78f59cbe 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,10 +31,10 @@ import org.scalatest.Tag
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame, DataFrameWriter, Row, SaveMode}
-import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
-import org.apache.spark.sql.catalyst.expressions.Hex
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Hex}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate}
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometSparkToColumnarExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
@@ -911,6 +911,29 @@ class CometExecSuite extends CometTestBase {
}
}
+ test("bloom_filter_agg") {
+ val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+ spark.sessionState.functionRegistry.registerFunction(
+ funcId_bloom_filter_agg,
+ new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+ (children: Seq[Expression]) =>
+ children.size match {
+ case 1 => new BloomFilterAggregate(children.head)
+ case 2 => new BloomFilterAggregate(children.head, children(1))
+ case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+ })
+
+ withParquetTable(
+ (0 until 100)
+ .map(_ => (Random.nextInt(), Random.nextInt() % 5)),
+ "tbl") {
+ val df = sql("SELECT bloom_filter_agg(cast(_2 as long)) FROM tbl")
+ checkSparkAnswerAndOperator(df)
+ }
+
+ spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+ }
+
test("sort (non-global)") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc)
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
index 7205484e..3dd930f6 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
@@ -22,6 +22,9 @@ package org.apache.spark.sql.benchmark
import org.apache.spark.SparkConf
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.internal.SQLConf
import org.apache.comet.{CometConf, CometSparkSessionExtensions}
@@ -222,23 +225,77 @@ object CometExecBenchmark extends CometBenchmarkBase {
}
}
- override def runCometBenchmark(mainArgs: Array[String]): Unit = {
- runBenchmarkWithTable("Subquery", 1024 * 1024 * 10) { v =>
- subqueryExecBenchmark(v)
- }
+ // BloomFilterAgg takes an argument for the expected number of distinct values, which
+ // determines filter size and number of hash functions. We use the cardinality as a hint to
+ // the aggregate; otherwise the default Spark values make a big filter with a lot of hash
+ // functions.
+ def bloomFilterAggregate(values: Int, cardinality: Int): Unit = {
+ val benchmark =
+ new Benchmark(
+ s"BloomFilterAggregate Exec (cardinality $cardinality)",
+ values,
+ output = output)
- runBenchmarkWithTable("Expand", 1024 * 1024 * 10) { v =>
- expandExecBenchmark(v)
- }
+ val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+ spark.sessionState.functionRegistry.registerFunction(
+ funcId_bloom_filter_agg,
+ new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+ (children: Seq[Expression]) => new BloomFilterAggregate(children.head, children(1)))
+
+ withTempPath { dir =>
+ withTempTable("parquetV1Table") {
+ prepareTable(dir, spark.sql(s"SELECT floor(rand() * $cardinality) as key FROM $tbl"))
- runBenchmarkWithTable("Project + Filter", 1024 * 1024 * 10) { v =>
- for (fractionOfZeros <- List(0.0, 0.50, 0.95)) {
- numericFilterExecBenchmark(v, fractionOfZeros)
+ val query =
+ s"SELECT bloom_filter_agg(cast(key as long), cast($cardinality as long)) FROM parquetV1Table"
+
+ benchmark.addCase("SQL Parquet - Spark (BloomFilterAgg)") { _ =>
+ spark.sql(query).noop()
+ }
+
+ benchmark.addCase("SQL Parquet - Comet (Scan) (BloomFilterAgg)") { _ =>
+ withSQLConf(CometConf.COMET_ENABLED.key -> "true") {
+ spark.sql(query).noop()
+ }
+ }
+
+ benchmark.addCase("SQL Parquet - Comet (Scan, Exec) (BloomFilterAgg)")
{ _ =>
+ withSQLConf(
+ CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_ENABLED.key -> "true") {
+ spark.sql(query).noop()
+ }
+ }
+
+ benchmark.run()
}
}
- runBenchmarkWithTable("Sort", 1024 * 1024 * 10) { v =>
- sortExecBenchmark(v)
+ spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+ }
+
+ override def runCometBenchmark(mainArgs: Array[String]): Unit = {
+// runBenchmarkWithTable("Subquery", 1024 * 1024 * 10) { v =>
+// subqueryExecBenchmark(v)
+// }
+//
+// runBenchmarkWithTable("Expand", 1024 * 1024 * 10) { v =>
+// expandExecBenchmark(v)
+// }
+//
+// runBenchmarkWithTable("Project + Filter", 1024 * 1024 * 10) { v =>
+// for (fractionOfZeros <- List(0.0, 0.50, 0.95)) {
+// numericFilterExecBenchmark(v, fractionOfZeros)
+// }
+// }
+//
+// runBenchmarkWithTable("Sort", 1024 * 1024 * 10) { v =>
+// sortExecBenchmark(v)
+// }
+
+ runBenchmarkWithTable("BloomFilterAggregate", 1024 * 1024 * 10) { v =>
+ for (card <- List(100, 1024, 1024 * 1024)) {
+ bloomFilterAggregate(v, card)
+ }
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]