This is an automated email from the ASF dual-hosted git repository. mneumann pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push: new 56394e0a50 feat: Support distinct window for sum (#16943) 56394e0a50 is described below commit 56394e0a501284559016d5dde83f140c761cab12 Author: Qi Zhu <821684...@qq.com> AuthorDate: Tue Jul 29 18:26:45 2025 +0800 feat: Support distinct window for sum (#16943) * feat: support sum distinct for window * fmt * fmt * fix test --- datafusion/ffi/src/udaf/mod.rs | 3 +- datafusion/functions-aggregate/src/sum.rs | 127 ++++++++++++++++++++++++-- datafusion/sqllogictest/test_files/window.slt | 85 +++++++++++++++-- 3 files changed, 198 insertions(+), 17 deletions(-) diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 17116e2461..66e1c28bb9 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -720,6 +720,7 @@ mod tests { let foreign_udaf = create_test_foreign_udaf(Sum::new())?; let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + // Note: sum distinct is only support Int64 until now let acc_args = AccumulatorArgs { return_field: Field::new("f", DataType::Float64, true).into(), schema: &schema, @@ -727,7 +728,7 @@ mod tests { order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)], is_reversed: false, name: "round_trip", - is_distinct: true, + is_distinct: false, exprs: &[col("a", &schema)?], }; diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 9495e087d2..97c0bbb976 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -34,7 +34,7 @@ use arrow::datatypes::{ }; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{ - exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, + exec_err, not_impl_err, utils::take_function_args, HashMap, Result, ScalarValue, }; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; @@ -243,12 +243,23 @@ impl AggregateUDFImpl for Sum { &self, args: AccumulatorArgs, ) -> Result<Box<dyn Accumulator>> { - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone()))) - }; + if args.is_distinct { + // distinct path: use our sliding‐window distinct‐sum + macro_rules! helper_distinct { + ($t:ty, $dt:expr) => { + Ok(Box::new(SlidingDistinctSumAccumulator::try_new(&$dt)?)) + }; + } + downcast_sum!(args, helper_distinct) + } else { + // non‐distinct path: existing sliding sum + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone()))) + }; + } + downcast_sum!(args, helper) } - downcast_sum!(args, helper) } fn reverse_expr(&self) -> ReversedUDAF { @@ -477,3 +488,107 @@ impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> { size_of_val(self) + self.values.capacity() * size_of::<T::Native>() } } + +/// A sliding‐window accumulator for `SUM(DISTINCT)` over Int64 columns. +/// Maintains a running sum so that `evaluate()` is O(1). +#[derive(Debug)] +pub struct SlidingDistinctSumAccumulator { + /// Map each distinct value → its current count in the window + counts: HashMap<i64, usize, RandomState>, + /// Running sum of all distinct keys currently in the window + sum: i64, + /// Data type (must be Int64) + data_type: DataType, +} + +impl SlidingDistinctSumAccumulator { + /// Create a new accumulator; only `DataType::Int64` is supported. + pub fn try_new(data_type: &DataType) -> Result<Self> { + // TODO support other numeric types + if *data_type != DataType::Int64 { + return exec_err!("SlidingDistinctSumAccumulator only supports Int64"); + } + Ok(Self { + counts: HashMap::default(), + sum: 0, + data_type: data_type.clone(), + }) + } +} + +impl Accumulator for SlidingDistinctSumAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = values[0].as_primitive::<Int64Type>(); + for &v in arr.values() { + let cnt = self.counts.entry(v).or_insert(0); + if *cnt == 0 { + // first occurrence in window + self.sum = self.sum.wrapping_add(v); + } + *cnt += 1; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result<ScalarValue> { + // O(1) wrap of running sum + Ok(ScalarValue::Int64(Some(self.sum))) + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> Result<Vec<ScalarValue>> { + // Serialize distinct keys for cross-partition merge if needed + let keys = self + .counts + .keys() + .cloned() + .map(Some) + .map(ScalarValue::Int64) + .collect::<Vec<_>>(); + Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable( + &keys, + &self.data_type, + ))]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // Merge distinct keys from other partitions + let list_arr = states[0].as_list::<i32>(); + for maybe_inner in list_arr.iter().flatten() { + for idx in 0..maybe_inner.len() { + if let ScalarValue::Int64(Some(v)) = + ScalarValue::try_from_array(&*maybe_inner, idx)? + { + let cnt = self.counts.entry(v).or_insert(0); + if *cnt == 0 { + self.sum = self.sum.wrapping_add(v); + } + *cnt += 1; + } + } + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = values[0].as_primitive::<Int64Type>(); + for &v in arr.values() { + if let Some(cnt) = self.counts.get_mut(&v) { + *cnt -= 1; + if *cnt == 0 { + // last copy leaving window + self.sum = self.sum.wrapping_sub(v); + self.counts.remove(&v); + } + } + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } +} diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index bed9121eec..44677fd5b9 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -5715,17 +5715,82 @@ EXPLAIN SELECT RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW ) AS distinct_count FROM table_test_distinct_count -ODER BY k, time; +ORDER BY k, time; ---- logical_plan -01)Projection: oder.k, oder.time, count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS normal_count, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS distinct_count -02)--WindowAggr: windowExpr=[[count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 12 [...] -03)----SubqueryAlias: oder +01)Sort: table_test_distinct_count.k ASC NULLS LAST, table_test_distinct_count.time ASC NULLS LAST +02)--Projection: table_test_distinct_count.k, table_test_distinct_count.time, count(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS normal_count, count(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS distinct_count +03)----WindowAggr: windowExpr=[[count(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS count(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW, count(DISTINCT table_test [...] 04)------TableScan: table_test_distinct_count projection=[k, v, time] physical_plan -01)ProjectionExec: expr=[k@0 as k, time@2 as time, count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@3 as normal_count, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@4 as distinct_count] -02)--BoundedWindowAggExec: wdw=[count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRE [...] -03)----SortExec: expr=[k@0 ASC NULLS LAST, time@2 ASC NULLS LAST], preserve_partitioning=[true] -04)------CoalesceBatchesExec: target_batch_size=1 -05)--------RepartitionExec: partitioning=Hash([k@0], 2), input_partitions=2 -06)----------DataSourceExec: partitions=2, partition_sizes=[5, 4] +01)SortPreservingMergeExec: [k@0 ASC NULLS LAST, time@1 ASC NULLS LAST] +02)--ProjectionExec: expr=[k@0 as k, time@2 as time, count(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@3 as normal_count, count(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@4 as distinct_count] +03)----BoundedWindowAggExec: wdw=[count(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "count(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, [...] +04)------SortExec: expr=[k@0 ASC NULLS LAST, time@2 ASC NULLS LAST], preserve_partitioning=[true] +05)--------CoalesceBatchesExec: target_batch_size=1 +06)----------RepartitionExec: partitioning=Hash([k@0], 2), input_partitions=2 +07)------------DataSourceExec: partitions=2, partition_sizes=[5, 4] + + +# Add testing for distinct sum +query TPII +SELECT + k, + time, + SUM(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS sum_v, + SUM(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS sum_distinct_v +FROM table_test_distinct_count +ORDER BY k, time; +---- +a 1970-01-01T00:01:00Z 1 1 +a 1970-01-01T00:02:00Z 2 1 +a 1970-01-01T00:03:00Z 5 3 +a 1970-01-01T00:03:00Z 5 3 +a 1970-01-01T00:04:00Z 5 3 +b 1970-01-01T00:01:00Z 3 3 +b 1970-01-01T00:02:00Z 6 3 +b 1970-01-01T00:03:00Z 14 7 +b 1970-01-01T00:03:00Z 14 7 + + + +query TT +EXPLAIN SELECT + k, + time, + SUM(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS sum_v, + SUM(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS sum_distinct_v +FROM table_test_distinct_count +ORDER BY k, time; +---- +logical_plan +01)Sort: table_test_distinct_count.k ASC NULLS LAST, table_test_distinct_count.time ASC NULLS LAST +02)--Projection: table_test_distinct_count.k, table_test_distinct_count.time, sum(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS sum_v, sum(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS sum_distinct_v +03)----WindowAggr: windowExpr=[[sum(__common_expr_1) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS sum(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW, sum(DISTINCT __common_expr_1) PARTITION B [...] +04)------Projection: CAST(table_test_distinct_count.v AS Int64) AS __common_expr_1, table_test_distinct_count.k, table_test_distinct_count.time +05)--------TableScan: table_test_distinct_count projection=[k, v, time] +physical_plan +01)SortPreservingMergeExec: [k@0 ASC NULLS LAST, time@1 ASC NULLS LAST] +02)--ProjectionExec: expr=[k@1 as k, time@2 as time, sum(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@3 as sum_v, sum(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@4 as sum_distinct_v] +03)----BoundedWindowAggExec: wdw=[sum(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "sum(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, meta [...] +04)------SortExec: expr=[k@1 ASC NULLS LAST, time@2 ASC NULLS LAST], preserve_partitioning=[true] +05)--------CoalesceBatchesExec: target_batch_size=1 +06)----------RepartitionExec: partitioning=Hash([k@1], 2), input_partitions=2 +07)------------ProjectionExec: expr=[CAST(v@1 AS Int64) as __common_expr_1, k@0 as k, time@2 as time] +08)--------------DataSourceExec: partitions=2, partition_sizes=[5, 4] --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org