This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
     new 969f683  feat: Support BloomFilterMightContain expr (#179)
969f683 is described below

commit 969f683c188fbcf2956da784b10777445f9adbc0
Author: advancedxy <xian...@apache.org>
AuthorDate: Fri Mar 15 05:24:16 2024 +0800

    feat: Support BloomFilterMightContain expr (#179)
---
 core/src/common/bit.rs                             |  11 ++
 .../expressions/bloom_filter_might_contain.rs      | 152 +++++++++++++++++++++
 core/src/execution/datafusion/expressions/mod.rs   |   1 +
 core/src/execution/datafusion/mod.rs               |   1 +
 core/src/execution/datafusion/planner.rs           |  10 ++
 core/src/execution/datafusion/spark_hash.rs        |   2 +-
 core/src/execution/datafusion/{ => util}/mod.rs    |   9 +-
 .../execution/datafusion/util/spark_bit_array.rs   | 131 ++++++++++++++++++
 .../datafusion/util/spark_bloom_filter.rs          |  98 +++++++++++++
 core/src/execution/proto/expr.proto                |   6 +
 pom.xml                                            |  10 ++
 spark/pom.xml                                      |  18 +++
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  17 +++
 .../apache/comet/shims/ShimQueryPlanSerde.scala    |   7 +-
 .../apache/comet/CometExpression3_3PlusSuite.scala | 106 ++++++++++++++
 15 files changed, 570 insertions(+), 9 deletions(-)

diff --git a/core/src/common/bit.rs b/core/src/common/bit.rs
index 4af560f..f736347 100644
--- a/core/src/common/bit.rs
+++ b/core/src/common/bit.rs
@@ -131,6 +131,17 @@ pub fn read_num_bytes_u32(size: usize, src: &[u8]) -> u32 {
     trailing_bits(v as u64, size * 8) as u32
 }
 
+/// Similar to `read_num_bytes`, but reads numbers from bytes in big-endian order.
+/// This is used to read bytes from Java's OutputStream, which writes bytes in big-endian order.
+macro_rules! read_num_be_bytes {
+    ($ty:ty, $size:expr, $src:expr) => {{
+        debug_assert!($size <= $src.len());
+        let mut buffer = <$ty as $crate::common::bit::FromBytes>::Buffer::default();
+        buffer.as_mut()[..$size].copy_from_slice(&$src[..$size]);
+        <$ty>::from_be_bytes(buffer)
+    }};
+}
+
 /// Converts value `val` of type `T` to a byte vector, by reading `num_bytes` from `val`.
 /// NOTE: if `val` is less than the size of `T` then it can be truncated.
 #[inline]
diff --git a/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs b/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
new file mode 100644
index 0000000..dd90cd8
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
@@ -0,0 +1,152 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
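
A quick illustration of the `read_num_be_bytes!` macro added to bit.rs above: for a full-width
read, an invocation like `read_num_be_bytes!(i32, 4, buf)` expands to roughly the sketch below,
with the crate's `FromBytes` trait supplying the fixed-size buffer. The helper name `read_i32_be`
is ours, not part of the patch; the big-endian decode matches the byte order produced by
java.io.DataOutputStream, which Spark's BloomFilterImpl.writeTo wraps around its output.

    fn read_i32_be(size: usize, src: &[u8]) -> i32 {
        debug_assert!(size <= 4 && size <= src.len());
        // copy `size` bytes to the front of a zeroed fixed buffer (the most
        // significant positions for big-endian), then decode
        let mut buffer = [0u8; 4];
        buffer[..size].copy_from_slice(&src[..size]);
        i32::from_be_bytes(buffer)
    }

    // e.g. read_i32_be(4, &[0x00, 0x00, 0x00, 0x2a]) == 42
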
+
+use crate::{
+    execution::datafusion::util::spark_bloom_filter::SparkBloomFilter, parquet::data_type::AsBytes,
+};
+use arrow::record_batch::RecordBatch;
+use arrow_array::cast::as_primitive_array;
+use arrow_schema::{DataType, Schema};
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
+use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, PhysicalExpr};
+use std::{
+    any::Any,
+    fmt::Display,
+    hash::{Hash, Hasher},
+    sync::Arc,
+};
+
+/// A physical expression that checks if a value might be in a bloom filter. It corresponds to
+/// Spark's `BloomFilterMightContain` expression.
+#[derive(Debug, Hash)]
+pub struct BloomFilterMightContain {
+    pub bloom_filter_expr: Arc<dyn PhysicalExpr>,
+    pub value_expr: Arc<dyn PhysicalExpr>,
+    bloom_filter: Option<SparkBloomFilter>,
+}
+
+impl Display for BloomFilterMightContain {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "BloomFilterMightContain [bloom_filter_expr: {}, value_expr: {}]",
+            self.bloom_filter_expr, self.value_expr
+        )
+    }
+}
+
+impl PartialEq<dyn Any> for BloomFilterMightContain {
+    fn eq(&self, _other: &dyn Any) -> bool {
+        down_cast_any_ref(_other)
+            .downcast_ref::<Self>()
+            .map(|other| {
+                self.bloom_filter_expr.eq(&other.bloom_filter_expr)
+                    && self.value_expr.eq(&other.value_expr)
+            })
+            .unwrap_or(false)
+    }
+}
+
+fn evaluate_bloom_filter(
+    bloom_filter_expr: &Arc<dyn PhysicalExpr>,
+) -> Result<Option<SparkBloomFilter>> {
+    // bloom_filter_expr must be a literal/scalar subquery expression, so we can evaluate it
+    // with an empty batch with an empty schema
+    let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
+    let bloom_filter_bytes = bloom_filter_expr.evaluate(&batch)?;
+    match bloom_filter_bytes {
+        ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
+            Ok(v.map(|v| SparkBloomFilter::new(v.as_bytes())))
+        }
+        _ => internal_err!("Bloom filter expression should be evaluated as a scalar binary value"),
+    }
+}
+
+impl BloomFilterMightContain {
+    pub fn try_new(
+        bloom_filter_expr: Arc<dyn PhysicalExpr>,
+        value_expr: Arc<dyn PhysicalExpr>,
+    ) -> Result<Self> {
+        // eagerly evaluate the bloom_filter_expr to get the actual bloom filter
+        let bloom_filter = evaluate_bloom_filter(&bloom_filter_expr)?;
+        Ok(Self {
+            bloom_filter_expr,
+            value_expr,
+            bloom_filter,
+        })
+    }
+}
+
+impl PhysicalExpr for BloomFilterMightContain {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(true)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        self.bloom_filter
+            .as_ref()
+            .map(|spark_filter| {
+                let values = self.value_expr.evaluate(batch)?;
+                match values {
+                    ColumnarValue::Array(array) => {
+                        let boolean_array =
+                            spark_filter.might_contain_longs(as_primitive_array(&array));
+                        Ok(ColumnarValue::Array(Arc::new(boolean_array)))
+                    }
+                    ColumnarValue::Scalar(ScalarValue::Int64(v)) => {
+                        let result = v.map(|v| spark_filter.might_contain_long(v));
+                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)))
+                    }
+                    _ => internal_err!("value expression should be int64 type"),
+                }
+            })
+            .unwrap_or_else(|| {
+                // when the bloom filter is null, we should return null for all inputs
+                Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))
+            })
+    }
+
+    fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
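        // NOTE: the order of the children matters; with_new_children below rebuilds the
        // expression from children[0] (the bloom filter) and children[1] (the probed value).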
+        vec![self.bloom_filter_expr.clone(), self.value_expr.clone()]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(BloomFilterMightContain::try_new(
+            children[0].clone(),
+            children[1].clone(),
+        )?))
+    }
+
+    fn dyn_hash(&self, state: &mut dyn Hasher) {
+        let mut s = state;
+        self.bloom_filter_expr.hash(&mut s);
+        self.value_expr.hash(&mut s);
+        self.hash(&mut s);
+    }
+}
diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs
index cfc3125..69cdf3e 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -26,6 +26,7 @@ pub mod scalar_funcs;
 pub use normalize_nan::NormalizeNaNAndZero;
 pub mod avg;
 pub mod avg_decimal;
+pub mod bloom_filter_might_contain;
 pub mod strings;
 pub mod subquery;
 pub mod sum_decimal;
diff --git a/core/src/execution/datafusion/mod.rs b/core/src/execution/datafusion/mod.rs
index f9fafeb..c464eee 100644
--- a/core/src/execution/datafusion/mod.rs
+++ b/core/src/execution/datafusion/mod.rs
@@ -22,3 +22,4 @@ mod operators;
 pub mod planner;
 pub(crate) mod shuffle_writer;
 mod spark_hash;
+mod util;
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index d52ec80..57c126b 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -56,6 +56,7 @@ use crate::{
         avg::Avg,
         avg_decimal::AvgDecimal,
         bitwise_not::BitwiseNotExpr,
+        bloom_filter_might_contain::BloomFilterMightContain,
         cast::Cast,
         checkoverflow::CheckOverflow,
         if_expr::IfExpr,
@@ -534,6 +535,15 @@ impl PhysicalPlanner {
                 let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
                 Ok(Arc::new(Subquery::new(self.exec_context_id, id, data_type)))
             }
+            ExprStruct::BloomFilterMightContain(expr) => {
+                let bloom_filter_expr =
+                    self.create_expr(expr.bloom_filter.as_ref().unwrap(), input_schema.clone())?;
+                let value_expr = self.create_expr(expr.value.as_ref().unwrap(), input_schema)?;
+                Ok(Arc::new(BloomFilterMightContain::try_new(
+                    bloom_filter_expr,
+                    value_expr,
+                )?))
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
diff --git a/core/src/execution/datafusion/spark_hash.rs b/core/src/execution/datafusion/spark_hash.rs
index aeefccf..1d8d1f2 100644
--- a/core/src/execution/datafusion/spark_hash.rs
+++ b/core/src/execution/datafusion/spark_hash.rs
@@ -32,7 +32,7 @@ use datafusion::{
 };
 
 #[inline]
-fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
+pub(crate) fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
     #[inline]
     fn mix_k1(mut k1: i32) -> i32 {
         k1 = k1.mul_wrapping(0xcc9e2d51u32 as i32);
diff --git a/core/src/execution/datafusion/mod.rs b/core/src/execution/datafusion/util/mod.rs
similarity index 85%
copy from core/src/execution/datafusion/mod.rs
copy to core/src/execution/datafusion/util/mod.rs
index f9fafeb..75b763a 100644
--- a/core/src/execution/datafusion/mod.rs
+++ b/core/src/execution/datafusion/util/mod.rs
@@ -15,10 +15,5 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Native execution through DataFusion
-
-mod expressions;
-mod operators;
-pub mod planner;
-pub(crate) mod shuffle_writer;
-mod spark_hash;
+pub mod spark_bit_array;
+pub mod spark_bloom_filter;
diff --git a/core/src/execution/datafusion/util/spark_bit_array.rs b/core/src/execution/datafusion/util/spark_bit_array.rs
new file mode 100644
index 0000000..9729627
--- /dev/null
+++ b/core/src/execution/datafusion/util/spark_bit_array.rs
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// A simple bit array implementation that simulates the behavior of Spark's BitArray, which is
+/// used in the BloomFilter implementation. Some methods are not implemented as they are not
+/// required for the current use case.
+#[derive(Debug, Hash)]
+pub struct SparkBitArray {
+    data: Vec<u64>,
+    bit_count: usize,
+}
+
+impl SparkBitArray {
+    pub fn new(buf: Vec<u64>) -> Self {
+        let num_bits = buf.iter().map(|x| x.count_ones() as usize).sum();
+        Self {
+            data: buf,
+            bit_count: num_bits,
+        }
+    }
+
+    pub fn set(&mut self, index: usize) -> bool {
+        if !self.get(index) {
+            // see the get method for the explanation of the shift operators
+            self.data[index >> 6] |= 1u64 << (index & 0x3f);
+            self.bit_count += 1;
+            true
+        } else {
+            false
+        }
+    }
+
+    pub fn get(&self, index: usize) -> bool {
+        // Java version: (data[(int) (index >> 6)] & (1L << (index))) != 0
+        // Rust and Java have different semantics for the shift operators. Java's shift operators
+        // explicitly mask the right-hand operand with 0x3f [1], while Rust's shift operators do
+        // not; they panic with "attempt to shift left with overflow" for a large right-hand
+        // operand. To fix this, we mask the right-hand operand with 0x3f on the Rust side.
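        // For example, with index = 64: Java evaluates 1L << 64 as 1L << (64 & 0x3f) == 1,
        // whereas Rust's 1u64 << 64 panics in a debug build; 1u64 << (64 & 0x3f) == 1
        // reproduces Java's behavior.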
+        // [1]: https://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.19
+        (self.data[index >> 6] & (1u64 << (index & 0x3f))) != 0
+    }
+
+    pub fn bit_size(&self) -> u64 {
+        self.data.len() as u64 * 64
+    }
+
+    pub fn cardinality(&self) -> usize {
+        self.bit_count
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_spark_bit_array() {
+        let buf = vec![0u64; 4];
+        let mut array = SparkBitArray::new(buf);
+        assert_eq!(array.bit_size(), 256);
+        assert_eq!(array.cardinality(), 0);
+
+        assert!(!array.get(0));
+        assert!(!array.get(1));
+        assert!(!array.get(63));
+        assert!(!array.get(64));
+        assert!(!array.get(65));
+        assert!(!array.get(127));
+        assert!(!array.get(128));
+        assert!(!array.get(129));
+
+        assert!(array.set(0));
+        assert!(array.set(1));
+        assert!(array.set(63));
+        assert!(array.set(64));
+        assert!(array.set(65));
+        assert!(array.set(127));
+        assert!(array.set(128));
+        assert!(array.set(129));
+
+        assert_eq!(array.cardinality(), 8);
+        assert_eq!(array.bit_size(), 256);
+
+        assert!(array.get(0));
+        // already set, so set should return false
+        assert!(!array.set(0));
+
+        // unset values should return false for get
+        assert!(!array.get(2));
+        assert!(!array.get(62));
+    }
+
+    #[test]
+    fn test_spark_bit_with_non_empty_buffer() {
+        let buf = vec![8u64; 4];
+        let mut array = SparkBitArray::new(buf);
+        assert_eq!(array.bit_size(), 256);
+        assert_eq!(array.cardinality(), 4);
+
+        // already set bits should return true
+        assert!(array.get(3));
+        assert!(array.get(67));
+        assert!(array.get(131));
+        assert!(array.get(195));
+
+        // other unset bits should return false
+        assert!(!array.get(0));
+        assert!(!array.get(1));
+
+        // set bits
+        assert!(array.set(0));
+        assert!(array.set(1));
+
+        // check cardinality
+        assert_eq!(array.cardinality(), 6);
+    }
+}
diff --git a/core/src/execution/datafusion/util/spark_bloom_filter.rs b/core/src/execution/datafusion/util/spark_bloom_filter.rs
new file mode 100644
index 0000000..22957a1
--- /dev/null
+++ b/core/src/execution/datafusion/util/spark_bloom_filter.rs
@@ -0,0 +1,98 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::execution::datafusion::{
+    spark_hash::spark_compatible_murmur3_hash, util::spark_bit_array::SparkBitArray,
+};
+use arrow_array::{ArrowNativeTypeOp, BooleanArray, Int64Array};
+
+const SPARK_BLOOM_FILTER_VERSION_1: i32 = 1;
+
+/// A Bloom filter implementation that simulates the behavior of Spark's BloomFilter.
+/// It's not a complete implementation of Spark's BloomFilter; it only adds the minimum
+/// methods needed to support mightContainLong on the native side.
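// For reference, the serialized layout consumed by SparkBloomFilter::new below, as written by
// Spark's BloomFilterImpl.writeTo (all fields big-endian via java.io.DataOutputStream):
//   [version: i32, always 1][numHashFunctions: i32][numWords: i32][numWords x i64 words]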
+#[derive(Debug, Hash)]
+pub struct SparkBloomFilter {
+    bits: SparkBitArray,
+    num_hash_functions: u32,
+}
+
+impl SparkBloomFilter {
+    pub fn new(buf: &[u8]) -> Self {
+        let mut offset = 0;
+        let version = read_num_be_bytes!(i32, 4, buf[offset..]);
+        offset += 4;
+        assert_eq!(
+            version, SPARK_BLOOM_FILTER_VERSION_1,
+            "Unsupported BloomFilter version: {}, expecting version: {}",
+            version, SPARK_BLOOM_FILTER_VERSION_1
+        );
+        let num_hash_functions = read_num_be_bytes!(i32, 4, buf[offset..]);
+        offset += 4;
+        let num_words = read_num_be_bytes!(i32, 4, buf[offset..]);
+        offset += 4;
+        let mut bits = vec![0u64; num_words as usize];
+        for i in 0..num_words {
+            bits[i as usize] = read_num_be_bytes!(i64, 8, buf[offset..]) as u64;
+            offset += 8;
+        }
+        Self {
+            bits: SparkBitArray::new(bits),
+            num_hash_functions: num_hash_functions as u32,
+        }
+    }
+
+    pub fn put_long(&mut self, item: i64) -> bool {
+        // Here we first hash the input long element into two 32-bit hash values, h1 and h2,
+        // then produce n hash values by `h1 + i * h2` with 1 <= i <= num_hash_functions.
+        let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
+        let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
+        let bit_size = self.bits.bit_size() as i32;
+        let mut bit_changed = false;
+        for i in 1..=self.num_hash_functions {
+            let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
+            if combined_hash < 0 {
+                combined_hash = !combined_hash;
+            }
+            bit_changed |= self.bits.set((combined_hash % bit_size) as usize)
+        }
+        bit_changed
+    }
+
+    pub fn might_contain_long(&self, item: i64) -> bool {
+        let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
+        let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
+        let bit_size = self.bits.bit_size() as i32;
+        for i in 1..=self.num_hash_functions {
+            let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
+            if combined_hash < 0 {
+                combined_hash = !combined_hash;
+            }
+            if !self.bits.get((combined_hash % bit_size) as usize) {
+                return false;
+            }
+        }
+        true
+    }
+
+    pub fn might_contain_longs(&self, items: &Int64Array) -> BooleanArray {
+        items
+            .iter()
+            .map(|v| v.map(|x| self.might_contain_long(x)))
+            .collect()
+    }
+}
diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index e8d35d1..58f607f 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -76,6 +76,7 @@ message Expr {
     Abs abs = 49;
     Subquery subquery = 50;
     UnboundReference unbound = 51;
+    BloomFilterMightContain bloom_filter_might_contain = 52;
   }
 }
 
@@ -432,6 +433,11 @@ message Subquery {
   DataType datatype = 2;
 }
 
+message BloomFilterMightContain {
+  Expr bloom_filter = 1;
+  Expr value = 2;
+}
+
 enum SortDirection {
   Ascending = 0;
   Descending = 1;
diff --git a/pom.xml b/pom.xml
index 7077aa3..bdc55ec 100644
--- a/pom.xml
+++ b/pom.xml
@@ -86,6 +86,7 @@ under the License.
       -Djdk.reflect.useDirectMethodHandle=false
     </extraJavaTestArgs>
     <argLine>-ea -Xmx4g -Xss4m ${extraJavaTestArgs}</argLine>
+    <additional.test.source>spark-3.3-plus</additional.test.source>
   </properties>
 
   <dependencyManagement>
@@ -494,6 +495,8 @@ under the License.
        <spark.version>3.2.2</spark.version>
        <spark.version.short>3.2</spark.version.short>
        <parquet.version>1.12.0</parquet.version>
+        <!-- we don't add special test suites for spark-3.2, so a non-existent dir is specified -->
+        <additional.test.source>not-needed-yet</additional.test.source>
       </properties>
     </profile>
 
@@ -504,6 +507,7 @@
        <spark.version>3.3.2</spark.version>
        <spark.version.short>3.3</spark.version.short>
        <parquet.version>1.12.0</parquet.version>
+        <additional.test.source>spark-3.3-plus</additional.test.source>
       </properties>
     </profile>
 
@@ -513,6 +517,7 @@
        <scala.version>2.12.17</scala.version>
        <spark.version.short>3.4</spark.version.short>
        <parquet.version>1.13.1</parquet.version>
+        <additional.test.source>spark-3.3-plus</additional.test.source>
       </properties>
     </profile>
 
@@ -777,6 +782,11 @@
         <artifactId>jacoco-maven-plugin</artifactId>
         <version>${jacoco.version}</version>
       </plugin>
+      <plugin>
+        <groupId>org.codehaus.mojo</groupId>
+        <artifactId>build-helper-maven-plugin</artifactId>
+        <version>3.2.0</version>
+      </plugin>
     </plugins>
   </pluginManagement>
   <plugins>
diff --git a/spark/pom.xml b/spark/pom.xml
index 7e54fde..31d80bb 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -233,6 +233,24 @@ under the License.
       <groupId>net.alchim31.maven</groupId>
       <artifactId>scala-maven-plugin</artifactId>
     </plugin>
+    <plugin>
+      <groupId>org.codehaus.mojo</groupId>
+      <artifactId>build-helper-maven-plugin</artifactId>
+      <executions>
+        <execution>
+          <id>add-test-source</id>
+          <phase>generate-test-sources</phase>
+          <goals>
+            <goal>add-test-source</goal>
+          </goals>
+          <configuration>
+            <sources>
+              <source>src/test/${additional.test.source}</source>
+            </sources>
+          </configuration>
+        </execution>
+      </executions>
+    </plugin>
   </plugins>
 </build>
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 87b4dff..abb4f0a 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1626,6 +1626,23 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
           "make_decimal",
           DecimalType(precision, scale),
           childExpr)
+      case b @ BinaryExpression(_, _) if isBloomFilterMightContain(b) =>
+        val bloomFilter = b.left
+        val value = b.right
+        val bloomFilterExpr = exprToProtoInternal(bloomFilter, inputs)
+        val valueExpr = exprToProtoInternal(value, inputs)
+        if (bloomFilterExpr.isDefined && valueExpr.isDefined) {
+          val builder = ExprOuterClass.BloomFilterMightContain.newBuilder()
+          builder.setBloomFilter(bloomFilterExpr.get)
+          builder.setValue(valueExpr.get)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setBloomFilterMightContain(builder)
+              .build())
+        } else {
+          None
+        }
 
       case e =>
         emitWarning(s"unsupported Spark expression: '$e' of class '${e.getClass.getName}'")
diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
index c47b399..7bdf2c0 100644
--- a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
@@ -19,7 +19,7 @@
 
 package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic
+import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, BinaryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
 
 trait ShimQueryPlanSerde {
@@ -44,4 +44,9 @@ trait ShimQueryPlanSerde {
       failOnError.head
     }
   }
+
+  // TODO: delete after dropping Spark 3.2 support
+  def isBloomFilterMightContain(binary: BinaryExpression): Boolean = {
+    binary.getClass.getName == "org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain"
+  }
 }
diff --git a/spark/src/test/spark-3.3-plus/org/apache/comet/CometExpression3_3PlusSuite.scala b/spark/src/test/spark-3.3-plus/org/apache/comet/CometExpression3_3PlusSuite.scala
new file mode 100644
index 0000000..6102777
--- /dev/null
+++ b/spark/src/test/spark-3.3-plus/org/apache/comet/CometExpression3_3PlusSuite.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import org.apache.spark.sql.{Column, CometTestBase}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, Expression, ExpressionInfo}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.util.sketch.BloomFilter
+
+import java.io.ByteArrayOutputStream
+import scala.util.Random
+
+class CometExpression3_3PlusSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+  import testImplicits._
+
+  val func_might_contain = new FunctionIdentifier("might_contain")
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Register 'might_contain' as a builtin function.
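    // BloomFilterMightContain is an internal expression that Spark's InjectRuntimeFilter rule
    // normally plants directly into plans, so it is not in the FunctionRegistry by default;
    // the suite registers it manually here and drops it again in afterAll().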
+    spark.sessionState.functionRegistry.registerFunction(func_might_contain,
+      new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
+      (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
+  }
+
+  override def afterAll(): Unit = {
+    spark.sessionState.functionRegistry.dropFunction(func_might_contain)
+    super.afterAll()
+  }
+
+  test("test BloomFilterMightContain can take a constant value input") {
+    val table = "test"
+
+    withTable(table) {
+      sql(s"create table $table(col1 long, col2 int) using parquet")
+      sql(s"insert into $table values (201, 1)")
+      checkSparkAnswerAndOperator(
+        s"""
+           |SELECT might_contain(
+           |X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267', col1) FROM $table
+           |""".stripMargin)
+    }
+  }
+
+  test("test NULL inputs for BloomFilterMightContain") {
+    val table = "test"
+
+    withTable(table) {
+      sql(s"create table $table(col1 long, col2 int) using parquet")
+      sql(s"insert into $table values (201, 1), (null, 2)")
+      checkSparkAnswerAndOperator(
+        s"""
+           |SELECT might_contain(null, null) both_null,
+           |       might_contain(null, 1L) null_bf,
+           |       might_contain(
+           |         X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267', col1) null_value
+           |  FROM $table
+           |""".stripMargin)
+    }
+  }
+
+  test("test BloomFilterMightContain from random input") {
+    val (longs, bfBytes) = bloomFilterFromRandomInput(10000, 10000)
+    val table = "test"
+
+    withTable(table) {
+      sql(s"create table $table(col1 long, col2 binary) using parquet")
+      spark.createDataset(longs).map(x => (x, bfBytes)).toDF("col1", "col2").write.insertInto(table)
+      val df = spark.table(table).select(new Column(BloomFilterMightContain(lit(bfBytes).expr, col("col1").expr)))
+      checkSparkAnswerAndOperator(df)
+      // check with scalar subquery
+      checkSparkAnswerAndOperator(
+        s"""
+           |SELECT might_contain((select first(col2) as col2 from $table), col1) FROM $table
+           |""".stripMargin)
+    }
+  }
+
+  private def bloomFilterFromRandomInput(expectedItems: Long, expectedBits: Long): (Seq[Long], Array[Byte]) = {
+    val bf = BloomFilter.create(expectedItems, expectedBits)
+    val longs = (0 until expectedItems.toInt).map(_ => Random.nextLong())
+    longs.foreach(bf.put)
+    val os = new ByteArrayOutputStream()
+    bf.writeTo(os)
+    (longs, os.toByteArray)
+  }
+}
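
To see how the native pieces above fit together, here is a sketch of a Rust round-trip test that
could sit in spark_bloom_filter.rs's test module; the hand-built buffer follows the version-1
layout noted earlier, and the test name and constants are ours, not part of the patch.

    #[test]
    fn test_put_and_might_contain_round_trip() {
        // version 1, one hash function, four empty 64-bit words (256 bits)
        let mut buf = Vec::new();
        buf.extend_from_slice(&1i32.to_be_bytes()); // version
        buf.extend_from_slice(&1i32.to_be_bytes()); // numHashFunctions
        buf.extend_from_slice(&4i32.to_be_bytes()); // numWords
        buf.extend_from_slice(&[0u8; 32]); // four zeroed words
        let mut filter = SparkBloomFilter::new(&buf);

        assert!(!filter.might_contain_long(42)); // an empty filter contains nothing
        assert!(filter.put_long(42)); // the first insert flips at least one bit
        assert!(filter.might_contain_long(42)); // the inserted value is now found
        assert!(!filter.put_long(42)); // re-inserting the same value changes no bits
    }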