This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new b22e4d2648 Improve performance of `last_value` by implementing special
`GroupsAccumulator` (#15542)
b22e4d2648 is described below
commit b22e4d26480202247022bf2c94ba629a07e3275b
Author: UBarney <[email protected]>
AuthorDate: Thu Apr 10 22:49:06 2025 +0800
Improve performance of `last_value` by implementing special
`GroupsAccumulator` (#15542)
* Improve performance of `last_value` by implementing special
`GroupsAccumulator`
* less diff
---
datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs | 26 +++
.../tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs | 4 +-
datafusion/functions-aggregate/src/first_last.rs | 231 +++++++++++++++++++--
datafusion/sqllogictest/test_files/group_by.slt | 27 ++-
4 files changed, 270 insertions(+), 18 deletions(-)
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 4ba63e9991..ff3b66986c 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -120,6 +120,32 @@ async fn test_first_val() {
.await;
}
+#[tokio::test(flavor = "multi_thread")]
+async fn test_last_val() {
+ let mut data_gen_config = baseline_config();
+
+ for i in 0..data_gen_config.columns.len() {
+ if data_gen_config.columns[i].get_max_num_distinct().is_none() {
+ data_gen_config.columns[i] = data_gen_config.columns[i]
+ .clone()
+ // Minimize the chance of identical values in the order by columns to make the test more stable
+ .with_max_num_distinct(usize::MAX);
+ }
+ }
+
+ let query_builder = QueryBuilder::new()
+ .with_table_name("fuzz_table")
+ .with_aggregate_function("last_value")
+ .with_aggregate_arguments(data_gen_config.all_columns())
+ .set_group_by_columns(data_gen_config.all_columns());
+
+ AggregationFuzzerBuilder::from(data_gen_config)
+ .add_query_builder(query_builder)
+ .build()
+ .run()
+ .await;
+}
+
#[tokio::test(flavor = "multi_thread")]
async fn test_max() {
let data_gen_config = baseline_config();
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
index 1e97ddcdb1..53e9288ab4 100644
--- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
@@ -503,7 +503,9 @@ impl QueryBuilder {
let distinct = if *is_distinct { "DISTINCT " } else { "" };
alias_gen += 1;
- let (order_by, null_opt) = if function_name.eq("first_value") {
+ let (order_by, null_opt) = if function_name.eq("first_value")
+ || function_name.eq("last_value")
+ {
(
self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */
self.null_opt(),
diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs
index 28e6a8723d..6465436375 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -166,6 +166,7 @@ impl AggregateUDFImpl for FirstValue {
}
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+ // TODO: extract to function
use DataType::*;
matches!(
args.return_type,
@@ -193,6 +194,7 @@ impl AggregateUDFImpl for FirstValue {
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
+ // TODO: extract to function
fn create_accumulator<T>(
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>>
@@ -210,6 +212,7 @@ impl AggregateUDFImpl for FirstValue {
args.ignore_nulls,
args.return_type,
&ordering_dtypes,
+ true,
)?))
}
@@ -258,10 +261,12 @@ impl AggregateUDFImpl for FirstValue {
create_accumulator::<Time64NanosecondType>(args)
}
- _ => internal_err!(
- "GroupsAccumulator not supported for first({})",
- args.return_type
- ),
+ _ => {
+ internal_err!(
+ "GroupsAccumulator not supported for first_value({})",
+ args.return_type
+ )
+ }
}
}
@@ -291,6 +296,7 @@ impl AggregateUDFImpl for FirstValue {
}
}
+// TODO: rename to PrimitiveGroupsAccumulator
struct FirstPrimitiveGroupsAccumulator<T>
where
T: ArrowPrimitiveType + Send,
@@ -316,12 +322,16 @@ where
// buffer for `get_filtered_min_of_each_group`
// filter_min_of_each_group_buf.0[group_idx] -> idx_in_val
// only valid if filter_min_of_each_group_buf.1[group_idx] == true
+ // TODO: rename to extreme_of_each_group_buf
min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),
// =========== option ============
// Stores the applicable ordering requirement.
ordering_req: LexOrdering,
+ // true: take first element in an aggregation group according to the requested ordering.
+ // false: take last element in an aggregation group according to the requested ordering.
+ pick_first_in_group: bool,
// derived from `ordering_req`.
sort_options: Vec<SortOptions>,
// Stores whether incoming data already satisfies the ordering requirement.
@@ -342,6 +352,7 @@ where
ignore_nulls: bool,
data_type: &DataType,
ordering_dtypes: &[DataType],
+ pick_first_in_group: bool,
) -> Result<Self> {
let requirement_satisfied = ordering_req.is_empty();
@@ -365,6 +376,7 @@ where
is_sets: BooleanBufferBuilder::new(0),
size_of_orderings: 0,
min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
+ pick_first_in_group,
})
}
@@ -391,8 +403,13 @@ where
assert!(new_ordering_values.len() == self.ordering_req.len());
let current_ordering = &self.orderings[group_idx];
- compare_rows(current_ordering, new_ordering_values, &self.sort_options)
- .map(|x| x.is_gt())
+ compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| {
+ if self.pick_first_in_group {
+ x.is_gt()
+ } else {
+ x.is_lt()
+ }
+ })
}
fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
@@ -501,10 +518,10 @@ where
.map(ScalarValue::size_of_vec)
.sum::<usize>()
}
-
/// Returns a vector of tuples `(group_idx, idx_in_val)` representing the index of the
/// minimum value in `orderings` for each group, using lexicographical comparison.
/// Values are filtered using `opt_filter` and `is_set_arr` if provided.
+ /// TODO: rename to get_filtered_extreme_of_each_group
fn get_filtered_min_of_each_group(
&mut self,
orderings: &[ArrayRef],
@@ -556,15 +573,19 @@ where
}
let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);
- if is_valid
- && comparator
- .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val)
- .is_gt()
- {
- self.min_of_each_group_buf.0[group_idx] = idx_in_val;
- } else if !is_valid {
+
+ if !is_valid {
self.min_of_each_group_buf.1.set_bit(group_idx, true);
self.min_of_each_group_buf.0[group_idx] = idx_in_val;
+ } else {
+ let ordering = comparator
+ .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val);
+
+ if (ordering.is_gt() && self.pick_first_in_group)
+ || (ordering.is_lt() && !self.pick_first_in_group)
+ {
+ self.min_of_each_group_buf.0[group_idx] = idx_in_val;
+ }
}
}
@@ -1052,6 +1073,109 @@ impl AggregateUDFImpl for LastValue {
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
+
+ fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+ use DataType::*;
+ matches!(
+ args.return_type,
+ Int8 | Int16
+ | Int32
+ | Int64
+ | UInt8
+ | UInt16
+ | UInt32
+ | UInt64
+ | Float16
+ | Float32
+ | Float64
+ | Decimal128(_, _)
+ | Decimal256(_, _)
+ | Date32
+ | Date64
+ | Time32(_)
+ | Time64(_)
+ | Timestamp(_, _)
+ )
+ }
+
+ fn create_groups_accumulator(
+ &self,
+ args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
+ fn create_accumulator<T>(
+ args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>>
+ where
+ T: ArrowPrimitiveType + Send,
+ {
+ let ordering_dtypes = args
+ .ordering_req
+ .iter()
+ .map(|e| e.expr.data_type(args.schema))
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
+ args.ordering_req.clone(),
+ args.ignore_nulls,
+ args.return_type,
+ &ordering_dtypes,
+ false,
+ )?))
+ }
+
+ match args.return_type {
+ DataType::Int8 => create_accumulator::<Int8Type>(args),
+ DataType::Int16 => create_accumulator::<Int16Type>(args),
+ DataType::Int32 => create_accumulator::<Int32Type>(args),
+ DataType::Int64 => create_accumulator::<Int64Type>(args),
+ DataType::UInt8 => create_accumulator::<UInt8Type>(args),
+ DataType::UInt16 => create_accumulator::<UInt16Type>(args),
+ DataType::UInt32 => create_accumulator::<UInt32Type>(args),
+ DataType::UInt64 => create_accumulator::<UInt64Type>(args),
+ DataType::Float16 => create_accumulator::<Float16Type>(args),
+ DataType::Float32 => create_accumulator::<Float32Type>(args),
+ DataType::Float64 => create_accumulator::<Float64Type>(args),
+
+ DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
+ DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
+
+ DataType::Timestamp(TimeUnit::Second, _) => {
+ create_accumulator::<TimestampSecondType>(args)
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, _) => {
+ create_accumulator::<TimestampMillisecondType>(args)
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ create_accumulator::<TimestampMicrosecondType>(args)
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+ create_accumulator::<TimestampNanosecondType>(args)
+ }
+
+ DataType::Date32 => create_accumulator::<Date32Type>(args),
+ DataType::Date64 => create_accumulator::<Date64Type>(args),
+ DataType::Time32(TimeUnit::Second) => {
+ create_accumulator::<Time32SecondType>(args)
+ }
+ DataType::Time32(TimeUnit::Millisecond) => {
+ create_accumulator::<Time32MillisecondType>(args)
+ }
+
+ DataType::Time64(TimeUnit::Microsecond) => {
+ create_accumulator::<Time64MicrosecondType>(args)
+ }
+ DataType::Time64(TimeUnit::Nanosecond) => {
+ create_accumulator::<Time64NanosecondType>(args)
+ }
+
+ _ => {
+ internal_err!(
+ "GroupsAccumulator not supported for last_value({})",
+ args.return_type
+ )
+ }
+ }
+ }
}
#[derive(Debug)]
@@ -1411,6 +1535,7 @@ mod tests {
true,
&DataType::Int64,
&[DataType::Int64],
+ true,
)?;
let mut val_with_orderings = {
@@ -1485,7 +1610,7 @@ mod tests {
}
#[test]
- fn test_frist_group_acc_size_of_ordering() -> Result<()> {
+ fn test_group_acc_size_of_ordering() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Int64, true),
@@ -1504,6 +1629,7 @@ mod tests {
true,
&DataType::Int64,
&[DataType::Int64],
+ true,
)?;
let val_with_orderings = {
@@ -1563,4 +1689,79 @@ mod tests {
Ok(())
}
+
+ #[test]
+ fn test_last_group_acc() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int64, true),
+ Field::new("b", DataType::Int64, true),
+ Field::new("c", DataType::Int64, true),
+ Field::new("d", DataType::Int32, true),
+ Field::new("e", DataType::Boolean, true),
+ ]));
+
+ let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
+ expr: col("c", &schema).unwrap(),
+ options: SortOptions::default(),
+ }]);
+
+ let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
+ sort_key,
+ true,
+ &DataType::Int64,
+ &[DataType::Int64],
+ false,
+ )?;
+
+ let mut val_with_orderings = {
+ let mut val_with_orderings = Vec::<ArrayRef>::new();
+
+ let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
+ let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
+
+ val_with_orderings.push(vals);
+ val_with_orderings.push(orderings);
+
+ val_with_orderings
+ };
+
+ group_acc.update_batch(
+ &val_with_orderings,
+ &[0, 1, 2, 1],
+ Some(&BooleanArray::from(vec![true, true, false, true])),
+ 3,
+ )?;
+
+ let state = group_acc.state(EmitTo::All)?;
+
+ let expected_state: Vec<Arc<dyn Array>> = vec![
+ Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
+ Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
+ Arc::new(BooleanArray::from(vec![true, true, false])),
+ ];
+ assert_eq!(state, expected_state);
+
+ group_acc.merge_batch(
+ &state,
+ &[0, 1, 2],
+ Some(&BooleanArray::from(vec![true, false, false])),
+ 3,
+ )?;
+
+ val_with_orderings.clear();
+ val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
+ val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
+
+ group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;
+
+ let binding = group_acc.evaluate(EmitTo::All)?;
+ let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();
+
+ let expect: PrimitiveArray<Int64Type> =
+ Int64Array::from(vec![Some(1), Some(66), Some(6), None]);
+
+ assert_eq!(eval_result, &expect);
+
+ Ok(())
+ }
}
diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt
index 4c4999a364..9e67018ecd 100644
--- a/datafusion/sqllogictest/test_files/group_by.slt
+++ b/datafusion/sqllogictest/test_files/group_by.slt
@@ -2232,7 +2232,7 @@ physical_plan
03)----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST]
query III
-SELECT a, b, LAST_VALUE(c) as last_c
+SELECT a, b, LAST_VALUE(c order by c) as last_c
FROM annotated_data_infinite2
GROUP BY a, b
----
@@ -2706,6 +2706,29 @@ select k, first_value(val order by o) respect NULLS from first_null group by k;
1 1
+statement ok
+CREATE TABLE last_null (
+ k INT,
+ val INT,
+ o int
+ ) as VALUES
+ (0, NULL, 9),
+ (0, 1, 1),
+ (1, 1, 1);
+
+query II rowsort
+select k, last_value(val order by o) IGNORE NULLS from last_null group by k;
+----
+0 1
+1 1
+
+query II rowsort
+select k, last_value(val order by o) respect NULLS from last_null group by k;
+----
+0 NULL
+1 1
+
+
query TT
EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts,
FIRST_VALUE(amount ORDER BY amount ASC) AS fv1,
@@ -3775,7 +3798,7 @@ ORDER BY x;
2 2
query II
-SELECT y, LAST_VALUE(x)
+SELECT y, LAST_VALUE(x order by x desc)
FROM FOO
GROUP BY y
ORDER BY y;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]