This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 1103a6c1 feat: Support more types with BloomFilterAgg (#1039)
1103a6c1 is described below
commit 1103a6c193b92cf0ce655b7e18a5a44aa4332476
Author: Matt Butrovich <[email protected]>
AuthorDate: Thu Oct 31 15:14:59 2024 -0400
feat: Support more types with BloomFilterAgg (#1039)
* Add support for other integer types. Next up is strings.
* Test refactor.
* Add string support for agg. Should probably go back and add support to
BloomFilterMightContain.
* Fix typo.
* Minor test refactor to reduce copy-paste.
* Add type checking in plan conversion, and fix signature on native side.
---
.../datafusion/expressions/bloom_filter_agg.rs | 35 ++++++++++++++++++----
.../datafusion/util/spark_bloom_filter.rs | 17 +++++++++++
.../org/apache/comet/serde/QueryPlanSerde.scala | 14 ++++++---
.../org/apache/comet/exec/CometExecSuite.scala | 8 +++--
4 files changed, 63 insertions(+), 11 deletions(-)
diff --git a/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs b/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
index ed64b80e..e6528a56 100644
--- a/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
+++ b/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
@@ -62,7 +62,17 @@ impl BloomFilterAgg {
assert!(matches!(data_type, DataType::Binary));
Self {
name: name.into(),
- signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable),
+ signature: Signature::uniform(
+ 1,
+ vec![
+ DataType::Int8,
+ DataType::Int16,
+ DataType::Int32,
+ DataType::Int64,
+ DataType::Utf8,
+ ],
+ Volatility::Immutable,
+ ),
expr,
num_items: extract_i32_from_literal(num_items),
num_bits: extract_i32_from_literal(num_bits),
@@ -112,10 +122,25 @@ impl Accumulator for SparkBloomFilter {
(0..arr.len()).try_for_each(|index| {
let v = ScalarValue::try_from_array(arr, index)?;
- if let ScalarValue::Int64(Some(value)) = v {
- self.put_long(value);
- } else {
- unreachable!()
+ match v {
+ ScalarValue::Int8(Some(value)) => {
+ self.put_long(value as i64);
+ }
+ ScalarValue::Int16(Some(value)) => {
+ self.put_long(value as i64);
+ }
+ ScalarValue::Int32(Some(value)) => {
+ self.put_long(value as i64);
+ }
+ ScalarValue::Int64(Some(value)) => {
+ self.put_long(value);
+ }
+ ScalarValue::Utf8(Some(value)) => {
+ self.put_binary(value.as_bytes());
+ }
+ _ => {
+ unreachable!()
+ }
}
Ok(())
})
diff --git a/native/core/src/execution/datafusion/util/spark_bloom_filter.rs b/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
index 22a84d85..35fa23b4 100644
--- a/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
+++ b/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
@@ -115,6 +115,23 @@ impl SparkBloomFilter {
bit_changed
}
+ pub fn put_binary(&mut self, item: &[u8]) -> bool {
+ // Here we first hash the input binary element into 2 int hash values, h1 and h2, then
+ // produce n hash values by `h1 + i * h2` with 1 <= i <= num_hash_functions.
+ let h1 = spark_compatible_murmur3_hash(item, 0);
+ let h2 = spark_compatible_murmur3_hash(item, h1);
+ let bit_size = self.bits.bit_size() as i32;
+ let mut bit_changed = false;
+ for i in 1..=self.num_hash_functions {
+ let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
+ if combined_hash < 0 {
+ combined_hash = !combined_hash;
+ }
+ bit_changed |= self.bits.set((combined_hash % bit_size) as usize)
+ }
+ bit_changed
+ }
+
pub fn might_contain_long(&self, item: i64) -> bool {
let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 3805d418..abb138b0 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -769,11 +769,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
val numBitsExpr = exprToProto(numBits, inputs, binding)
val dataType = serializeDataType(bloom_filter.dataType)
- // TODO: Support more types
- // https://github.com/apache/datafusion-comet/issues/1023
if (childExpr.isDefined &&
- child.dataType
- .isInstanceOf[LongType] &&
+ (child.dataType
+ .isInstanceOf[ByteType] ||
+ child.dataType
+ .isInstanceOf[ShortType] ||
+ child.dataType
+ .isInstanceOf[IntegerType] ||
+ child.dataType
+ .isInstanceOf[LongType] ||
+ child.dataType
+ .isInstanceOf[StringType]) &&
numItemsExpr.isDefined &&
numBitsExpr.isDefined &&
dataType.isDefined) {
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index a720842c..99007d0c 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -946,8 +946,12 @@ class CometExecSuite extends CometTestBase {
(0 until 100)
.map(_ => (Random.nextInt(), Random.nextInt() % 5)),
"tbl") {
- val df = sql("SELECT bloom_filter_agg(cast(_2 as long)) FROM tbl")
- checkSparkAnswerAndOperator(df)
+
+ (if (isSpark35Plus) Seq("tinyint", "short", "int", "long", "string") else Seq("long"))
+ .foreach { input_type =>
+ val df = sql(f"SELECT bloom_filter_agg(cast(_2 as $input_type)) FROM tbl")
+ checkSparkAnswerAndOperator(df)
+ }
}
spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]