This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 1624d63070 perf: Add support for `GroupsAccumulator` to `string_agg` (#21154)
1624d63070 is described below
commit 1624d63070c05dde448a14727bd20764fc7887ad
Author: Neil Conway <[email protected]>
AuthorDate: Thu Mar 26 16:51:15 2026 -0400
perf: Add support for `GroupsAccumulator` to `string_agg` (#21154)
## Which issue does this PR close?
- Closes #17789.
## Rationale for this change
`string_agg` previously didn't support the `GroupsAccumulator` API;
adding support for it can significantly improve performance,
particularly when there are many groups.
Benchmarks (M4 Max):
- string_agg_query_group_by_few_groups (~10): 645 µs → 564 µs, -11%
- string_agg_query_group_by_mid_groups (~1,000): 2,692 µs → 871 µs, -68%
- string_agg_query_group_by_many_groups (~65,000): 16,606 µs → 1,147 µs, -93%
## What changes are included in this PR?
* Add end-to-end benchmark for `string_agg`
* Implement `GroupsAccumulator` API for `string_agg`
* Add unit tests
* Minor code cleanup for existing `string_agg` code paths
## Are these changes tested?
Yes.
## Are there any user-facing changes?
No, other than a change to an error message string.
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/core/benches/aggregate_query_sql.rs | 33 ++
datafusion/functions-aggregate/src/string_agg.rs | 405 ++++++++++++++++++++---
datafusion/sqllogictest/test_files/aggregate.slt | 2 +-
3 files changed, 384 insertions(+), 56 deletions(-)
diff --git a/datafusion/core/benches/aggregate_query_sql.rs b/datafusion/core/benches/aggregate_query_sql.rs
index 402ac9c717..d7e24aceba 100644
--- a/datafusion/core/benches/aggregate_query_sql.rs
+++ b/datafusion/core/benches/aggregate_query_sql.rs
@@ -295,6 +295,39 @@ fn criterion_benchmark(c: &mut Criterion) {
)
})
});
+
+ c.bench_function("string_agg_query_group_by_few_groups", |b| {
+ b.iter(|| {
+ query(
+ ctx.clone(),
+ &rt,
+ "SELECT u64_narrow, string_agg(utf8, ',') \
+ FROM t GROUP BY u64_narrow",
+ )
+ })
+ });
+
+ c.bench_function("string_agg_query_group_by_mid_groups", |b| {
+ b.iter(|| {
+ query(
+ ctx.clone(),
+ &rt,
+ "SELECT u64_mid, string_agg(utf8, ',') \
+ FROM t GROUP BY u64_mid",
+ )
+ })
+ });
+
+ c.bench_function("string_agg_query_group_by_many_groups", |b| {
+ b.iter(|| {
+ query(
+ ctx.clone(),
+ &rt,
+ "SELECT u64_wide, string_agg(utf8, ',') \
+ FROM t GROUP BY u64_wide",
+ )
+ })
+ });
}
criterion_group!(benches, criterion_benchmark);
diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs
index 6f1a37302f..ea3914b1e3 100644
--- a/datafusion/functions-aggregate/src/string_agg.rs
+++ b/datafusion/functions-aggregate/src/string_agg.rs
@@ -20,23 +20,24 @@
use std::any::Any;
use std::hash::Hash;
use std::mem::size_of_val;
+use std::sync::Arc;
use crate::array_agg::ArrayAgg;
-use arrow::array::ArrayRef;
+use arrow::array::{ArrayRef, AsArray, BooleanArray, LargeStringArray};
use arrow::datatypes::{DataType, Field, FieldRef};
-use datafusion_common::cast::{
- as_generic_string_array, as_string_array, as_string_view_array,
-};
+use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
use datafusion_common::{
Result, ScalarValue, internal_datafusion_err, internal_err, not_impl_err,
};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
- Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature,
Volatility,
+ Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator,
Signature,
+ TypeSignature, Volatility,
};
use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
+use
datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls;
use datafusion_macros::user_doc;
use datafusion_physical_expr::expressions::Literal;
@@ -117,6 +118,27 @@ impl StringAgg {
array_agg: Default::default(),
}
}
+
+ /// Extract the delimiter string from the second argument expression.
+ fn extract_delimiter(args: &AccumulatorArgs) -> Result<String> {
+ let Some(lit) = args.exprs[1].as_any().downcast_ref::<Literal>() else {
+ return not_impl_err!("string_agg delimiter must be a string
literal");
+ };
+
+ if lit.value().is_null() {
+ return Ok(String::new());
+ }
+
+ match lit.value().try_as_str() {
+ Some(s) => Ok(s.unwrap_or("").to_string()),
+ None => {
+ not_impl_err!(
+ "string_agg not supported for delimiter \"{}\"",
+ lit.value()
+ )
+ }
+ }
+ }
}
impl Default for StringAgg {
@@ -125,8 +147,10 @@ impl Default for StringAgg {
}
}
-/// If there is no `distinct` and `order by` required by the `string_agg` call, a
-/// more efficient accumulator `SimpleStringAggAccumulator` will be used.
+/// Three accumulation strategies depending on query shape:
+/// - No DISTINCT / ORDER BY with GROUP BY: `StringAggGroupsAccumulator`
+/// - No DISTINCT / ORDER BY without GROUP BY: `SimpleStringAggAccumulator`
+/// - With DISTINCT or ORDER BY: `StringAggAccumulator` (delegates to `ArrayAgg`)
impl AggregateUDFImpl for StringAgg {
fn as_any(&self) -> &dyn Any {
self
@@ -145,11 +169,7 @@ impl AggregateUDFImpl for StringAgg {
}
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
- // See comments in `impl AggregateUDFImpl ...` for more detail
- let no_order_no_distinct =
- (args.ordering_fields.is_empty()) && (!args.is_distinct);
- if no_order_no_distinct {
- // Case `SimpleStringAggAccumulator`
+ if !args.is_distinct && args.ordering_fields.is_empty() {
Ok(vec![
Field::new(
format_state_name(args.name, "string_agg"),
@@ -159,40 +179,16 @@ impl AggregateUDFImpl for StringAgg {
.into(),
])
} else {
- // Case `StringAggAccumulator`
self.array_agg.state_fields(args)
}
}
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
- let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>()
else {
- return not_impl_err!(
- "The second argument of the string_agg function must be a
string literal"
- );
- };
-
- let delimiter = if lit.value().is_null() {
- // If the second argument (the delimiter that joins strings) is
NULL, join
- // on an empty string. (e.g. [a, b, c] => "abc").
- ""
- } else if let Some(lit_string) = lit.value().try_as_str() {
- lit_string.unwrap_or("")
- } else {
- return not_impl_err!(
- "StringAgg not supported for delimiter \"{}\"",
- lit.value()
- );
- };
-
- // See comments in `impl AggregateUDFImpl ...` for more detail
- let no_order_no_distinct =
- acc_args.order_bys.is_empty() && (!acc_args.is_distinct);
+ let delimiter = Self::extract_delimiter(&acc_args)?;
- if no_order_no_distinct {
- // simple case (more efficient)
- Ok(Box::new(SimpleStringAggAccumulator::new(delimiter)))
+ if !acc_args.is_distinct && acc_args.order_bys.is_empty() {
+ Ok(Box::new(SimpleStringAggAccumulator::new(&delimiter)))
} else {
- // general case
let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
return_field: Field::new(
"f",
@@ -215,7 +211,7 @@ impl AggregateUDFImpl for StringAgg {
Ok(Box::new(StringAggAccumulator::new(
array_agg_acc,
- delimiter,
+ &delimiter,
)))
}
}
@@ -224,6 +220,18 @@ impl AggregateUDFImpl for StringAgg {
datafusion_expr::ReversedUDAF::Reversed(string_agg_udaf())
}
+ fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+ !args.is_distinct && args.order_bys.is_empty()
+ }
+
+ fn create_groups_accumulator(
+ &self,
+ args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
+ let delimiter = Self::extract_delimiter(&args)?;
+ Ok(Box::new(StringAggGroupsAccumulator::new(delimiter)))
+ }
+
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
@@ -315,10 +323,136 @@ fn filter_index<T: Clone>(values: &[T], index: usize) ->
Vec<T> {
.collect::<Vec<_>>()
}
-/// StringAgg accumulator for the simple case (no order or distinct specified)
-/// This accumulator is more efficient than `StringAggAccumulator`
-/// because it accumulates the string directly,
-/// whereas `StringAggAccumulator` uses `ArrayAggAccumulator`.
+/// GroupsAccumulator for `string_agg` without DISTINCT or ORDER BY.
+#[derive(Debug)]
+struct StringAggGroupsAccumulator {
+ /// The delimiter placed between concatenated values.
+ delimiter: String,
+ /// Accumulated string per group. `None` means no values have been seen
+ /// (the group's output will be NULL).
+ /// A potential improvement is to avoid this String allocation
+ /// See <https://github.com/apache/datafusion/issues/21156>
+ values: Vec<Option<String>>,
+ /// Running total of string data bytes across all groups.
+ total_data_bytes: usize,
+}
+
+impl StringAggGroupsAccumulator {
+ fn new(delimiter: String) -> Self {
+ Self {
+ delimiter,
+ values: Vec::new(),
+ total_data_bytes: 0,
+ }
+ }
+
+ fn append_batch<'a>(
+ &mut self,
+ iter: impl Iterator<Item = Option<&'a str>>,
+ group_indices: &[usize],
+ ) {
+ for (opt_value, &group_idx) in iter.zip(group_indices.iter()) {
+ if let Some(value) = opt_value {
+ match &mut self.values[group_idx] {
+ Some(existing) => {
+ let added = self.delimiter.len() + value.len();
+ existing.reserve(added);
+ existing.push_str(&self.delimiter);
+ existing.push_str(value);
+ self.total_data_bytes += added;
+ }
+ slot @ None => {
+ *slot = Some(value.to_string());
+ self.total_data_bytes += value.len();
+ }
+ }
+ }
+ }
+ }
+}
+
+impl GroupsAccumulator for StringAggGroupsAccumulator {
+ fn update_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ self.values.resize(total_num_groups, None);
+ let array = apply_filter_as_nulls(&values[0], opt_filter)?;
+ match array.data_type() {
+ DataType::Utf8 => {
+ self.append_batch(array.as_string::<i32>().iter(),
group_indices)
+ }
+ DataType::LargeUtf8 => {
+ self.append_batch(array.as_string::<i64>().iter(),
group_indices)
+ }
+ DataType::Utf8View => {
+ self.append_batch(array.as_string_view().iter(), group_indices)
+ }
+ other => {
+ return internal_err!("string_agg unexpected data type:
{other}");
+ }
+ }
+ Ok(())
+ }
+
+ fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+ let to_emit = emit_to.take_needed(&mut self.values);
+ let emitted_bytes: usize = to_emit
+ .iter()
+ .filter_map(|opt| opt.as_ref().map(|s| s.len()))
+ .sum();
+ self.total_data_bytes -= emitted_bytes;
+
+ let result: ArrayRef = Arc::new(LargeStringArray::from(to_emit));
+ Ok(result)
+ }
+
+ fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+ self.evaluate(emit_to).map(|arr| vec![arr])
+ }
+
+ fn merge_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ // State is always LargeUtf8, which update_batch already handles.
+ self.update_batch(values, group_indices, opt_filter, total_num_groups)
+ }
+
+ fn convert_to_state(
+ &self,
+ values: &[ArrayRef],
+ opt_filter: Option<&BooleanArray>,
+ ) -> Result<Vec<ArrayRef>> {
+ let input = apply_filter_as_nulls(&values[0], opt_filter)?;
+ let result = if input.data_type() == &DataType::LargeUtf8 {
+ input
+ } else {
+ arrow::compute::cast(&input, &DataType::LargeUtf8)?
+ };
+ Ok(vec![result])
+ }
+
+ fn supports_convert_to_state(&self) -> bool {
+ true
+ }
+
+ fn size(&self) -> usize {
+ self.total_data_bytes
+ + self.values.capacity() * size_of::<Option<String>>()
+ + self.delimiter.capacity()
+ + size_of_val(self)
+ }
+}
+
+/// Per-row accumulator for `string_agg` without DISTINCT or ORDER BY. Used for
+/// non-grouped aggregation; grouped queries use [`StringAggGroupsAccumulator`].
#[derive(Debug)]
pub(crate) struct SimpleStringAggAccumulator {
delimiter: String,
@@ -331,7 +465,7 @@ impl SimpleStringAggAccumulator {
pub fn new(delimiter: &str) -> Self {
Self {
delimiter: delimiter.to_string(),
- accumulated_string: "".to_string(),
+ accumulated_string: String::new(),
has_value: false,
}
}
@@ -361,18 +495,11 @@ impl Accumulator for SimpleStringAggAccumulator {
})?;
match string_arr.data_type() {
- DataType::Utf8 => {
- let array = as_string_array(string_arr)?;
- self.append_strings(array.iter());
- }
+ DataType::Utf8 =>
self.append_strings(string_arr.as_string::<i32>().iter()),
DataType::LargeUtf8 => {
- let array = as_generic_string_array::<i64>(string_arr)?;
- self.append_strings(array.iter());
- }
- DataType::Utf8View => {
- let array = as_string_view_array(string_arr)?;
- self.append_strings(array.iter());
+ self.append_strings(string_arr.as_string::<i64>().iter())
}
+ DataType::Utf8View =>
self.append_strings(string_arr.as_string_view().iter()),
other => {
return internal_err!(
"Planner should ensure string_agg first argument is
Utf8-like, found {other}"
@@ -662,4 +789,172 @@ mod tests {
acc1.merge_batch(&intermediate_state)?;
Ok(acc1)
}
+
+ // ---------------------------------------------------------------
+ // Tests for StringAggGroupsAccumulator
+ // ---------------------------------------------------------------
+
+ fn make_groups_acc(delimiter: &str) -> StringAggGroupsAccumulator {
+ StringAggGroupsAccumulator::new(delimiter.to_string())
+ }
+
+ /// Helper: evaluate and downcast to LargeStringArray
+ fn evaluate_groups(
+ acc: &mut StringAggGroupsAccumulator,
+ emit_to: EmitTo,
+ ) -> Vec<Option<String>> {
+ let result = acc.evaluate(emit_to).unwrap();
+ let arr = result.as_any().downcast_ref::<LargeStringArray>().unwrap();
+ arr.iter().map(|v| v.map(|s| s.to_string())).collect()
+ }
+
+ #[test]
+ fn groups_basic() -> Result<()> {
+ let mut acc = make_groups_acc(",");
+
+ // 6 rows, 3 groups: group 0 gets "a","d"; group 1 gets "b","e"; group
2 gets "c","f"
+ let values: ArrayRef =
+ Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d", "e",
"f"]));
+ let group_indices = vec![0, 1, 2, 0, 1, 2];
+ acc.update_batch(&[values], &group_indices, None, 3)?;
+
+ let result = evaluate_groups(&mut acc, EmitTo::All);
+ assert_eq!(
+ result,
+ vec![
+ Some("a,d".to_string()),
+ Some("b,e".to_string()),
+ Some("c,f".to_string()),
+ ]
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn groups_with_nulls() -> Result<()> {
+ let mut acc = make_groups_acc("|");
+
+ // Group 0: "a", NULL, "c" → "a|c"
+ // Group 1: NULL, "b" → "b"
+ // Group 2: NULL only → NULL
+ let values: ArrayRef = Arc::new(LargeStringArray::from(vec![
+ Some("a"),
+ None,
+ Some("c"),
+ None,
+ Some("b"),
+ None,
+ ]));
+ let group_indices = vec![0, 1, 0, 2, 1, 2];
+ acc.update_batch(&[values], &group_indices, None, 3)?;
+
+ let result = evaluate_groups(&mut acc, EmitTo::All);
+ assert_eq!(
+ result,
+ vec![Some("a|c".to_string()), Some("b".to_string()), None,]
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn groups_with_filter() -> Result<()> {
+ let mut acc = make_groups_acc(",");
+
+ let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b",
"c", "d"]));
+ let group_indices = vec![0, 0, 1, 1];
+ // Filter: only rows 0 and 3 are included
+ let filter = BooleanArray::from(vec![true, false, false, true]);
+ acc.update_batch(&[values], &group_indices, Some(&filter), 2)?;
+
+ let result = evaluate_groups(&mut acc, EmitTo::All);
+ assert_eq!(result, vec![Some("a".to_string()), Some("d".to_string())]);
+ Ok(())
+ }
+
+ #[test]
+ fn groups_emit_first() -> Result<()> {
+ let mut acc = make_groups_acc(",");
+
+ let values: ArrayRef =
+ Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d", "e",
"f"]));
+ let group_indices = vec![0, 1, 2, 0, 1, 2];
+ acc.update_batch(&[values], &group_indices, None, 3)?;
+
+ // Emit only the first 2 groups
+ let result = evaluate_groups(&mut acc, EmitTo::First(2));
+ assert_eq!(
+ result,
+ vec![Some("a,d".to_string()), Some("b,e".to_string())]
+ );
+
+ // Group 2 (now shifted to index 0) should still be intact
+ let result = evaluate_groups(&mut acc, EmitTo::All);
+ assert_eq!(result, vec![Some("c,f".to_string())]);
+ Ok(())
+ }
+
+ #[test]
+ fn groups_merge_batch() -> Result<()> {
+ let mut acc = make_groups_acc(",");
+
+ // First batch: group 0 = "a", group 1 = "b"
+ let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a",
"b"]));
+ acc.update_batch(&[values], &[0, 1], None, 2)?;
+
+ // Simulate a second accumulator's state (LargeUtf8 partial strings)
+ let partial_state: ArrayRef =
Arc::new(LargeStringArray::from(vec!["c,d", "e"]));
+ acc.merge_batch(&[partial_state], &[0, 1], None, 2)?;
+
+ let result = evaluate_groups(&mut acc, EmitTo::All);
+ assert_eq!(
+ result,
+ vec![Some("a,c,d".to_string()), Some("b,e".to_string())]
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn groups_empty_groups() -> Result<()> {
+ let mut acc = make_groups_acc(",");
+
+ // 4 groups total, but only groups 0 and 2 receive values
+ let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a",
"b"]));
+ acc.update_batch(&[values], &[0, 2], None, 4)?;
+
+ let result = evaluate_groups(&mut acc, EmitTo::All);
+ assert_eq!(
+ result,
+ vec![
+ Some("a".to_string()),
+ None, // group 1: never received a value
+ Some("b".to_string()),
+ None, // group 3: never received a value
+ ]
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn groups_multiple_batches() -> Result<()> {
+ let mut acc = make_groups_acc("|");
+
+ // Batch 1: 2 groups
+ let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a",
"b"]));
+ acc.update_batch(&[values], &[0, 1], None, 2)?;
+
+ // Batch 2: same groups, plus a new group
+ let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["c", "d",
"e"]));
+ acc.update_batch(&[values], &[0, 1, 2], None, 3)?;
+
+ let result = evaluate_groups(&mut acc, EmitTo::All);
+ assert_eq!(
+ result,
+ vec![
+ Some("a|c".to_string()),
+ Some("b|d".to_string()),
+ Some("e".to_string()),
+ ]
+ );
+ Ok(())
+ }
}
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 1f2a81d334..e42ebd4ce7 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -6991,7 +6991,7 @@ SELECT STRING_AGG(DISTINCT x,'|' ORDER BY x) FROM strings
----
a|b|i|j|p|x|y|z
-query error This feature is not implemented: The second argument of the string_agg function must be a string literal
+query error This feature is not implemented: string_agg delimiter must be a string literal
SELECT STRING_AGG(DISTINCT x,y) FROM strings
query error Execution error: In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]