This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 646f40a443 Implement special min/max accumulator for Strings and
Binary (10% faster for Clickbench Q28) (#12792)
646f40a443 is described below
commit 646f40a44330cdcfad5fc779897046d1dc0b83c5
Author: Andrew Lamb <[email protected]>
AuthorDate: Sun Oct 13 08:07:12 2024 -0400
Implement special min/max accumulator for Strings and Binary (10% faster
for Clickbench Q28) (#12792)
* Implement special min/max accumulator for Strings:
`MinMaxBytesAccumulator`
* fix bug
* fix msrv
* move code, handle filters
* simplify
* Add functional tests
* remove unnecessary test
* improve docs
* improve docs
* cleanup
* improve comments
* fix diagram
* fix accounting
* Use correct type in memory accounting
* Add TODO comment
---
.../src/aggregate/groups_accumulator/accumulate.rs | 2 +-
.../src/aggregate/groups_accumulator/nulls.rs | 115 ++++-
datafusion/functions-aggregate/src/min_max.rs | 123 ++---
.../src/min_max/min_max_bytes.rs | 515 +++++++++++++++++++++
datafusion/sqllogictest/test_files/aggregate.slt | 174 +++++++
5 files changed, 872 insertions(+), 57 deletions(-)
diff --git
a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
index a0475fe8e4..3efd348937 100644
---
a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
+++
b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
@@ -95,7 +95,7 @@ impl NullState {
///
/// When value_fn is called it also sets
///
- /// 1. `self.seen_values[group_index]` to true for all rows that had a non
null vale
+ /// 1. `self.seen_values[group_index]` to true for all rows that had a non
null value
pub fn accumulate<T, F>(
&mut self,
group_indices: &[usize],
diff --git
a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs
b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs
index 25212f7f0f..6a8946034c 100644
---
a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs
+++
b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs
@@ -15,13 +15,22 @@
// specific language governing permissions and limitations
// under the License.
-//! [`set_nulls`], and [`filtered_null_mask`], utilities for working with nulls
+//! [`set_nulls`], other utilities for working with nulls
-use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray};
+use arrow::array::{
+ Array, ArrayRef, ArrowNumericType, AsArray, BinaryArray, BinaryViewArray,
+ BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray,
StringArray,
+ StringViewArray,
+};
use arrow::buffer::NullBuffer;
+use arrow::datatypes::DataType;
+use datafusion_common::{not_impl_err, Result};
+use std::sync::Arc;
/// Sets the validity mask for a `PrimitiveArray` to `nulls`
/// replacing any existing null mask
+///
+/// See [`set_nulls_dyn`] for a version that works with `Array`
pub fn set_nulls<T: ArrowNumericType + Send>(
array: PrimitiveArray<T>,
nulls: Option<NullBuffer>,
@@ -91,3 +100,105 @@ pub fn filtered_null_mask(
let opt_filter = opt_filter.and_then(filter_to_nulls);
NullBuffer::union(opt_filter.as_ref(), input.nulls())
}
+
+/// Applies optional filter to input, returning a new array of the same type
+/// with the same data, but with any values that were filtered out set to null
+pub fn apply_filter_as_nulls(
+ input: &dyn Array,
+ opt_filter: Option<&BooleanArray>,
+) -> Result<ArrayRef> {
+ let nulls = filtered_null_mask(opt_filter, input);
+ set_nulls_dyn(input, nulls)
+}
+
+/// Replaces the nulls in the input array with the given `NullBuffer`
+///
+/// TODO: replace when upstreamed in arrow-rs:
<https://github.com/apache/arrow-rs/issues/6528>
+pub fn set_nulls_dyn(input: &dyn Array, nulls: Option<NullBuffer>) ->
Result<ArrayRef> {
+ if let Some(nulls) = nulls.as_ref() {
+ assert_eq!(nulls.len(), input.len());
+ }
+
+ let output: ArrayRef = match input.data_type() {
+ DataType::Utf8 => {
+ let input = input.as_string::<i32>();
+ // safety: values / offsets came from a valid string array, so are
valid utf8
+ // and we checked nulls has the same length as values
+ unsafe {
+ Arc::new(StringArray::new_unchecked(
+ input.offsets().clone(),
+ input.values().clone(),
+ nulls,
+ ))
+ }
+ }
+ DataType::LargeUtf8 => {
+ let input = input.as_string::<i64>();
+ // safety: values / offsets came from a valid string array, so are
valid utf8
+ // and we checked nulls has the same length as values
+ unsafe {
+ Arc::new(LargeStringArray::new_unchecked(
+ input.offsets().clone(),
+ input.values().clone(),
+ nulls,
+ ))
+ }
+ }
+ DataType::Utf8View => {
+ let input = input.as_string_view();
+ // safety: values / views came from a valid string view array, so
are valid utf8
+ // and we checked nulls has the same length as values
+ unsafe {
+ Arc::new(StringViewArray::new_unchecked(
+ input.views().clone(),
+ input.data_buffers().to_vec(),
+ nulls,
+ ))
+ }
+ }
+
+ DataType::Binary => {
+ let input = input.as_binary::<i32>();
+ // safety: values / offsets came from a valid binary array
+ // and we checked nulls has the same length as values
+ unsafe {
+ Arc::new(BinaryArray::new_unchecked(
+ input.offsets().clone(),
+ input.values().clone(),
+ nulls,
+ ))
+ }
+ }
+ DataType::LargeBinary => {
+ let input = input.as_binary::<i64>();
+ // safety: values / offsets came from a valid large binary array
+ // and we checked nulls has the same length as values
+ unsafe {
+ Arc::new(LargeBinaryArray::new_unchecked(
+ input.offsets().clone(),
+ input.values().clone(),
+ nulls,
+ ))
+ }
+ }
+ DataType::BinaryView => {
+ let input = input.as_binary_view();
+ // safety: values / views came from a valid binary view array
+ // and we checked nulls has the same length as values
+ unsafe {
+ Arc::new(BinaryViewArray::new_unchecked(
+ input.views().clone(),
+ input.data_buffers().to_vec(),
+ nulls,
+ ))
+ }
+ }
+ _ => {
+ return not_impl_err!("Applying nulls {:?}", input.data_type());
+ }
+ };
+ assert_eq!(input.len(), output.len());
+ assert_eq!(input.data_type(), output.data_type());
+
+ Ok(output)
+}
diff --git a/datafusion/functions-aggregate/src/min_max.rs
b/datafusion/functions-aggregate/src/min_max.rs
index 3d2915fd09..2f7954a8ee 100644
--- a/datafusion/functions-aggregate/src/min_max.rs
+++ b/datafusion/functions-aggregate/src/min_max.rs
@@ -17,6 +17,8 @@
//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function
//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function
+mod min_max_bytes;
+
use arrow::array::{
ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array,
Date64Array,
Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array,
@@ -50,6 +52,7 @@ use arrow::datatypes::{
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
};
+use crate::min_max::min_max_bytes::MinMaxBytesAccumulator;
use datafusion_common::ScalarValue;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation,
Signature,
@@ -104,7 +107,7 @@ impl Default for Max {
/// the specified [`ArrowPrimitiveType`].
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
-macro_rules! instantiate_max_accumulator {
+macro_rules! primitive_max_accumulator {
($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
Ok(Box::new(
PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur,
new| {
@@ -123,7 +126,7 @@ macro_rules! instantiate_max_accumulator {
///
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
-macro_rules! instantiate_min_accumulator {
+macro_rules! primitive_min_accumulator {
($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
Ok(Box::new(
PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur,
new| {
@@ -231,6 +234,12 @@ impl AggregateUDFImpl for Max {
| Time32(_)
| Time64(_)
| Timestamp(_, _)
+ | Utf8
+ | LargeUtf8
+ | Utf8View
+ | Binary
+ | LargeBinary
+ | BinaryView
)
}
@@ -242,58 +251,58 @@ impl AggregateUDFImpl for Max {
use TimeUnit::*;
let data_type = args.return_type;
match data_type {
- Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type),
- Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type),
- Int32 => instantiate_max_accumulator!(data_type, i32, Int32Type),
- Int64 => instantiate_max_accumulator!(data_type, i64, Int64Type),
- UInt8 => instantiate_max_accumulator!(data_type, u8, UInt8Type),
- UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type),
- UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type),
- UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type),
+ Int8 => primitive_max_accumulator!(data_type, i8, Int8Type),
+ Int16 => primitive_max_accumulator!(data_type, i16, Int16Type),
+ Int32 => primitive_max_accumulator!(data_type, i32, Int32Type),
+ Int64 => primitive_max_accumulator!(data_type, i64, Int64Type),
+ UInt8 => primitive_max_accumulator!(data_type, u8, UInt8Type),
+ UInt16 => primitive_max_accumulator!(data_type, u16, UInt16Type),
+ UInt32 => primitive_max_accumulator!(data_type, u32, UInt32Type),
+ UInt64 => primitive_max_accumulator!(data_type, u64, UInt64Type),
Float16 => {
- instantiate_max_accumulator!(data_type, f16, Float16Type)
+ primitive_max_accumulator!(data_type, f16, Float16Type)
}
Float32 => {
- instantiate_max_accumulator!(data_type, f32, Float32Type)
+ primitive_max_accumulator!(data_type, f32, Float32Type)
}
Float64 => {
- instantiate_max_accumulator!(data_type, f64, Float64Type)
+ primitive_max_accumulator!(data_type, f64, Float64Type)
}
- Date32 => instantiate_max_accumulator!(data_type, i32, Date32Type),
- Date64 => instantiate_max_accumulator!(data_type, i64, Date64Type),
+ Date32 => primitive_max_accumulator!(data_type, i32, Date32Type),
+ Date64 => primitive_max_accumulator!(data_type, i64, Date64Type),
Time32(Second) => {
- instantiate_max_accumulator!(data_type, i32, Time32SecondType)
+ primitive_max_accumulator!(data_type, i32, Time32SecondType)
}
Time32(Millisecond) => {
- instantiate_max_accumulator!(data_type, i32,
Time32MillisecondType)
+ primitive_max_accumulator!(data_type, i32,
Time32MillisecondType)
}
Time64(Microsecond) => {
- instantiate_max_accumulator!(data_type, i64,
Time64MicrosecondType)
+ primitive_max_accumulator!(data_type, i64,
Time64MicrosecondType)
}
Time64(Nanosecond) => {
- instantiate_max_accumulator!(data_type, i64,
Time64NanosecondType)
+ primitive_max_accumulator!(data_type, i64,
Time64NanosecondType)
}
Timestamp(Second, _) => {
- instantiate_max_accumulator!(data_type, i64,
TimestampSecondType)
+ primitive_max_accumulator!(data_type, i64, TimestampSecondType)
}
Timestamp(Millisecond, _) => {
- instantiate_max_accumulator!(data_type, i64,
TimestampMillisecondType)
+ primitive_max_accumulator!(data_type, i64,
TimestampMillisecondType)
}
Timestamp(Microsecond, _) => {
- instantiate_max_accumulator!(data_type, i64,
TimestampMicrosecondType)
+ primitive_max_accumulator!(data_type, i64,
TimestampMicrosecondType)
}
Timestamp(Nanosecond, _) => {
- instantiate_max_accumulator!(data_type, i64,
TimestampNanosecondType)
+ primitive_max_accumulator!(data_type, i64,
TimestampNanosecondType)
}
Decimal128(_, _) => {
- instantiate_max_accumulator!(data_type, i128, Decimal128Type)
+ primitive_max_accumulator!(data_type, i128, Decimal128Type)
}
Decimal256(_, _) => {
- instantiate_max_accumulator!(data_type, i256, Decimal256Type)
+ primitive_max_accumulator!(data_type, i256, Decimal256Type)
+ }
+ Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView =>
{
+
Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone())))
}
-
- // It would be nice to have a fast implementation for Strings as
well
- // https://github.com/apache/datafusion/issues/6906
// This is only reached if groups_accumulator_supported is out of
sync
_ => internal_err!("GroupsAccumulator not supported for max({})",
data_type),
@@ -1057,6 +1066,12 @@ impl AggregateUDFImpl for Min {
| Time32(_)
| Time64(_)
| Timestamp(_, _)
+ | Utf8
+ | LargeUtf8
+ | Utf8View
+ | Binary
+ | LargeBinary
+ | BinaryView
)
}
@@ -1068,58 +1083,58 @@ impl AggregateUDFImpl for Min {
use TimeUnit::*;
let data_type = args.return_type;
match data_type {
- Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type),
- Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type),
- Int32 => instantiate_min_accumulator!(data_type, i32, Int32Type),
- Int64 => instantiate_min_accumulator!(data_type, i64, Int64Type),
- UInt8 => instantiate_min_accumulator!(data_type, u8, UInt8Type),
- UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type),
- UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type),
- UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type),
+ Int8 => primitive_min_accumulator!(data_type, i8, Int8Type),
+ Int16 => primitive_min_accumulator!(data_type, i16, Int16Type),
+ Int32 => primitive_min_accumulator!(data_type, i32, Int32Type),
+ Int64 => primitive_min_accumulator!(data_type, i64, Int64Type),
+ UInt8 => primitive_min_accumulator!(data_type, u8, UInt8Type),
+ UInt16 => primitive_min_accumulator!(data_type, u16, UInt16Type),
+ UInt32 => primitive_min_accumulator!(data_type, u32, UInt32Type),
+ UInt64 => primitive_min_accumulator!(data_type, u64, UInt64Type),
Float16 => {
- instantiate_min_accumulator!(data_type, f16, Float16Type)
+ primitive_min_accumulator!(data_type, f16, Float16Type)
}
Float32 => {
- instantiate_min_accumulator!(data_type, f32, Float32Type)
+ primitive_min_accumulator!(data_type, f32, Float32Type)
}
Float64 => {
- instantiate_min_accumulator!(data_type, f64, Float64Type)
+ primitive_min_accumulator!(data_type, f64, Float64Type)
}
- Date32 => instantiate_min_accumulator!(data_type, i32, Date32Type),
- Date64 => instantiate_min_accumulator!(data_type, i64, Date64Type),
+ Date32 => primitive_min_accumulator!(data_type, i32, Date32Type),
+ Date64 => primitive_min_accumulator!(data_type, i64, Date64Type),
Time32(Second) => {
- instantiate_min_accumulator!(data_type, i32, Time32SecondType)
+ primitive_min_accumulator!(data_type, i32, Time32SecondType)
}
Time32(Millisecond) => {
- instantiate_min_accumulator!(data_type, i32,
Time32MillisecondType)
+ primitive_min_accumulator!(data_type, i32,
Time32MillisecondType)
}
Time64(Microsecond) => {
- instantiate_min_accumulator!(data_type, i64,
Time64MicrosecondType)
+ primitive_min_accumulator!(data_type, i64,
Time64MicrosecondType)
}
Time64(Nanosecond) => {
- instantiate_min_accumulator!(data_type, i64,
Time64NanosecondType)
+ primitive_min_accumulator!(data_type, i64,
Time64NanosecondType)
}
Timestamp(Second, _) => {
- instantiate_min_accumulator!(data_type, i64,
TimestampSecondType)
+ primitive_min_accumulator!(data_type, i64, TimestampSecondType)
}
Timestamp(Millisecond, _) => {
- instantiate_min_accumulator!(data_type, i64,
TimestampMillisecondType)
+ primitive_min_accumulator!(data_type, i64,
TimestampMillisecondType)
}
Timestamp(Microsecond, _) => {
- instantiate_min_accumulator!(data_type, i64,
TimestampMicrosecondType)
+ primitive_min_accumulator!(data_type, i64,
TimestampMicrosecondType)
}
Timestamp(Nanosecond, _) => {
- instantiate_min_accumulator!(data_type, i64,
TimestampNanosecondType)
+ primitive_min_accumulator!(data_type, i64,
TimestampNanosecondType)
}
Decimal128(_, _) => {
- instantiate_min_accumulator!(data_type, i128, Decimal128Type)
+ primitive_min_accumulator!(data_type, i128, Decimal128Type)
}
Decimal256(_, _) => {
- instantiate_min_accumulator!(data_type, i256, Decimal256Type)
+ primitive_min_accumulator!(data_type, i256, Decimal256Type)
+ }
+ Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView =>
{
+
Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone())))
}
-
- // It would be nice to have a fast implementation for Strings as
well
- // https://github.com/apache/datafusion/issues/6906
// This is only reached if groups_accumulator_supported is out of
sync
_ => internal_err!("GroupsAccumulator not supported for min({})",
data_type),
diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs
b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs
new file mode 100644
index 0000000000..e3f01b91bf
--- /dev/null
+++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs
@@ -0,0 +1,515 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{
+ Array, ArrayRef, AsArray, BinaryBuilder, BinaryViewBuilder, BooleanArray,
+ LargeBinaryBuilder, LargeStringBuilder, StringBuilder, StringViewBuilder,
+};
+use arrow_schema::DataType;
+use datafusion_common::{internal_err, Result};
+use datafusion_expr::{EmitTo, GroupsAccumulator};
+use
datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls;
+use std::sync::Arc;
+
+/// Implements fast Min/Max [`GroupsAccumulator`] for "bytes" types
([`StringArray`],
+/// [`BinaryArray`], [`StringViewArray`], etc)
+///
+/// This implementation dispatches to the appropriate specialized code in
+/// [`MinMaxBytesState`] based on data type and comparison function
+///
+/// [`StringArray`]: arrow::array::StringArray
+/// [`BinaryArray`]: arrow::array::BinaryArray
+/// [`StringViewArray`]: arrow::array::StringViewArray
+#[derive(Debug)]
+pub(crate) struct MinMaxBytesAccumulator {
+ /// Inner data storage.
+ inner: MinMaxBytesState,
+ /// if true, is `MIN` otherwise is `MAX`
+ is_min: bool,
+}
+
+impl MinMaxBytesAccumulator {
+ /// Create a new accumulator for computing `min(val)`
+ pub fn new_min(data_type: DataType) -> Self {
+ Self {
+ inner: MinMaxBytesState::new(data_type),
+ is_min: true,
+ }
+ }
+
+    /// Create a new accumulator for computing `max(val)`
+ pub fn new_max(data_type: DataType) -> Self {
+ Self {
+ inner: MinMaxBytesState::new(data_type),
+ is_min: false,
+ }
+ }
+}
+
+impl GroupsAccumulator for MinMaxBytesAccumulator {
+ fn update_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ let array = &values[0];
+ assert_eq!(array.len(), group_indices.len());
+ assert_eq!(array.data_type(), &self.inner.data_type);
+
+ // apply filter if needed
+ let array = apply_filter_as_nulls(array, opt_filter)?;
+
+ // dispatch to appropriate kernel / specialized implementation
+ fn string_min(a: &[u8], b: &[u8]) -> bool {
+ // safety: only called from this function, which ensures a and b
come
+ // from an array with valid utf8 data
+ unsafe {
+ let a = std::str::from_utf8_unchecked(a);
+ let b = std::str::from_utf8_unchecked(b);
+ a < b
+ }
+ }
+ fn string_max(a: &[u8], b: &[u8]) -> bool {
+ // safety: only called from this function, which ensures a and b
come
+ // from an array with valid utf8 data
+ unsafe {
+ let a = std::str::from_utf8_unchecked(a);
+ let b = std::str::from_utf8_unchecked(b);
+ a > b
+ }
+ }
+ fn binary_min(a: &[u8], b: &[u8]) -> bool {
+ a < b
+ }
+
+ fn binary_max(a: &[u8], b: &[u8]) -> bool {
+ a > b
+ }
+
+ fn str_to_bytes<'a>(
+ it: impl Iterator<Item = Option<&'a str>>,
+ ) -> impl Iterator<Item = Option<&'a [u8]>> {
+ it.map(|s| s.map(|s| s.as_bytes()))
+ }
+
+ match (self.is_min, &self.inner.data_type) {
+ // Utf8/LargeUtf8/Utf8View Min
+ (true, &DataType::Utf8) => self.inner.update_batch(
+ str_to_bytes(array.as_string::<i32>().iter()),
+ group_indices,
+ total_num_groups,
+ string_min,
+ ),
+ (true, &DataType::LargeUtf8) => self.inner.update_batch(
+ str_to_bytes(array.as_string::<i64>().iter()),
+ group_indices,
+ total_num_groups,
+ string_min,
+ ),
+ (true, &DataType::Utf8View) => self.inner.update_batch(
+ str_to_bytes(array.as_string_view().iter()),
+ group_indices,
+ total_num_groups,
+ string_min,
+ ),
+
+ // Utf8/LargeUtf8/Utf8View Max
+ (false, &DataType::Utf8) => self.inner.update_batch(
+ str_to_bytes(array.as_string::<i32>().iter()),
+ group_indices,
+ total_num_groups,
+ string_max,
+ ),
+ (false, &DataType::LargeUtf8) => self.inner.update_batch(
+ str_to_bytes(array.as_string::<i64>().iter()),
+ group_indices,
+ total_num_groups,
+ string_max,
+ ),
+ (false, &DataType::Utf8View) => self.inner.update_batch(
+ str_to_bytes(array.as_string_view().iter()),
+ group_indices,
+ total_num_groups,
+ string_max,
+ ),
+
+ // Binary/LargeBinary/BinaryView Min
+ (true, &DataType::Binary) => self.inner.update_batch(
+ array.as_binary::<i32>().iter(),
+ group_indices,
+ total_num_groups,
+ binary_min,
+ ),
+ (true, &DataType::LargeBinary) => self.inner.update_batch(
+ array.as_binary::<i64>().iter(),
+ group_indices,
+ total_num_groups,
+ binary_min,
+ ),
+ (true, &DataType::BinaryView) => self.inner.update_batch(
+ array.as_binary_view().iter(),
+ group_indices,
+ total_num_groups,
+ binary_min,
+ ),
+
+ // Binary/LargeBinary/BinaryView Max
+ (false, &DataType::Binary) => self.inner.update_batch(
+ array.as_binary::<i32>().iter(),
+ group_indices,
+ total_num_groups,
+ binary_max,
+ ),
+ (false, &DataType::LargeBinary) => self.inner.update_batch(
+ array.as_binary::<i64>().iter(),
+ group_indices,
+ total_num_groups,
+ binary_max,
+ ),
+ (false, &DataType::BinaryView) => self.inner.update_batch(
+ array.as_binary_view().iter(),
+ group_indices,
+ total_num_groups,
+ binary_max,
+ ),
+
+ _ => internal_err!(
+ "Unexpected combination for MinMaxBytesAccumulator: ({:?},
{:?})",
+ self.is_min,
+ self.inner.data_type
+ ),
+ }
+ }
+
+ fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+ let (data_capacity, min_maxes) = self.inner.emit_to(emit_to);
+
+ // Convert the Vec of bytes to a vec of Strings (at no cost)
+ fn bytes_to_str(
+ min_maxes: Vec<Option<Vec<u8>>>,
+ ) -> impl Iterator<Item = Option<String>> {
+ min_maxes.into_iter().map(|opt| {
+ opt.map(|bytes| {
+ // Safety: only called on data added from update_batch
which ensures
+ // the input type matched the output type
+ unsafe { String::from_utf8_unchecked(bytes) }
+ })
+ })
+ }
+
+ let result: ArrayRef = match self.inner.data_type {
+ DataType::Utf8 => {
+ let mut builder =
+ StringBuilder::with_capacity(min_maxes.len(),
data_capacity);
+ for opt in bytes_to_str(min_maxes) {
+ match opt {
+ None => builder.append_null(),
+ Some(s) => builder.append_value(s.as_str()),
+ }
+ }
+ Arc::new(builder.finish())
+ }
+ DataType::LargeUtf8 => {
+ let mut builder =
+ LargeStringBuilder::with_capacity(min_maxes.len(),
data_capacity);
+ for opt in bytes_to_str(min_maxes) {
+ match opt {
+ None => builder.append_null(),
+ Some(s) => builder.append_value(s.as_str()),
+ }
+ }
+ Arc::new(builder.finish())
+ }
+ DataType::Utf8View => {
+ let block_size = capacity_to_view_block_size(data_capacity);
+
+ let mut builder =
StringViewBuilder::with_capacity(min_maxes.len())
+ .with_fixed_block_size(block_size);
+ for opt in bytes_to_str(min_maxes) {
+ match opt {
+ None => builder.append_null(),
+ Some(s) => builder.append_value(s.as_str()),
+ }
+ }
+ Arc::new(builder.finish())
+ }
+ DataType::Binary => {
+ let mut builder =
+ BinaryBuilder::with_capacity(min_maxes.len(),
data_capacity);
+ for opt in min_maxes {
+ match opt {
+ None => builder.append_null(),
+ Some(s) => builder.append_value(s.as_ref() as &[u8]),
+ }
+ }
+ Arc::new(builder.finish())
+ }
+ DataType::LargeBinary => {
+ let mut builder =
+ LargeBinaryBuilder::with_capacity(min_maxes.len(),
data_capacity);
+ for opt in min_maxes {
+ match opt {
+ None => builder.append_null(),
+ Some(s) => builder.append_value(s.as_ref() as &[u8]),
+ }
+ }
+ Arc::new(builder.finish())
+ }
+ DataType::BinaryView => {
+ let block_size = capacity_to_view_block_size(data_capacity);
+
+ let mut builder =
BinaryViewBuilder::with_capacity(min_maxes.len())
+ .with_fixed_block_size(block_size);
+ for opt in min_maxes {
+ match opt {
+ None => builder.append_null(),
+ Some(s) => builder.append_value(s.as_ref() as &[u8]),
+ }
+ }
+ Arc::new(builder.finish())
+ }
+ _ => {
+ return internal_err!(
+ "Unexpected data type for MinMaxBytesAccumulator: {:?}",
+ self.inner.data_type
+ );
+ }
+ };
+
+ assert_eq!(&self.inner.data_type, result.data_type());
+ Ok(result)
+ }
+
+ fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+ // min/max are their own states (no transition needed)
+ self.evaluate(emit_to).map(|arr| vec![arr])
+ }
+
+ fn merge_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ // min/max are their own states (no transition needed)
+ self.update_batch(values, group_indices, opt_filter, total_num_groups)
+ }
+
+ fn convert_to_state(
+ &self,
+ values: &[ArrayRef],
+ opt_filter: Option<&BooleanArray>,
+ ) -> Result<Vec<ArrayRef>> {
+ // Min/max do not change the values as they are their own states
+ // apply the filter by combining with the null mask, if any
+ let output = apply_filter_as_nulls(&values[0], opt_filter)?;
+ Ok(vec![output])
+ }
+
+ fn supports_convert_to_state(&self) -> bool {
+ true
+ }
+
+ fn size(&self) -> usize {
+ self.inner.size()
+ }
+}
+
+/// Returns the block size in bytes (contiguous buffer size) to use
+/// for a given data capacity (total string length)
+///
+/// This is a heuristic to avoid allocating too many small buffers
+fn capacity_to_view_block_size(data_capacity: usize) -> u32 {
+ let max_block_size = 2 * 1024 * 1024;
+ if let Ok(block_size) = u32::try_from(data_capacity) {
+ block_size.min(max_block_size)
+ } else {
+ max_block_size
+ }
+}
+
+/// Stores internal Min/Max state for "bytes" types.
+///
+/// This implementation is general and stores the minimum/maximum for each
+/// groups in an individual byte array, which balances allocations and memory
+/// fragmentation (aka garbage).
+///
+/// ```text
+/// ┌─────────────────────────────────┐
+/// ┌─────┐ ┌────▶│Option<Vec<u8>> (["A"]) │───────────▶ "A"
+/// │ 0 │────┘ └─────────────────────────────────┘
+/// ├─────┤ ┌─────────────────────────────────┐
+/// │ 1 │─────────▶│Option<Vec<u8>> (["Z"]) │───────────▶ "Z"
+/// └─────┘ └─────────────────────────────────┘ ...
+/// ... ...
+/// ┌─────┐ ┌────────────────────────────────┐
+/// │ N-2 │─────────▶│Option<Vec<u8>> (["A"]) │────────────▶ "A"
+/// ├─────┤ └────────────────────────────────┘
+/// │ N-1 │────┐ ┌────────────────────────────────┐
+/// └─────┘ └────▶│Option<Vec<u8>> (["Q"]) │────────────▶ "Q"
+/// └────────────────────────────────┘
+///
+/// min_max: Vec<Option<Vec<u8>>>
+/// ```
+///
+/// Note that for `StringViewArray` and `BinaryViewArray`, there are
potentially
+/// more efficient implementations (e.g. by managing a string data buffer
+/// directly), but then garbage collection, memory management, and final array
+/// construction becomes more complex.
+///
+/// See discussion on <https://github.com/apache/datafusion/issues/6906>
+#[derive(Debug)]
+struct MinMaxBytesState {
+ /// The minimum/maximum value for each group
+ min_max: Vec<Option<Vec<u8>>>,
+ /// The data type of the array
+ data_type: DataType,
+ /// The total bytes of the string data (for pre-allocating the final array,
+ /// and tracking memory usage)
+ total_data_bytes: usize,
+}
+
+#[derive(Debug, Clone, Copy)]
+enum MinMaxLocation<'a> {
+ /// the min/max value is stored in the existing `min_max` array
+ ExistingMinMax,
+ /// the min/max value is stored in the input array at the given index
+ Input(&'a [u8]),
+}
+
+/// Implement the MinMaxBytesState with a comparison function
+/// for comparing strings
+impl MinMaxBytesState {
+    /// Create a new MinMaxBytesState
+ ///
+ /// # Arguments:
+ /// * `data_type`: The data type of the arrays that will be passed to this
accumulator
+ fn new(data_type: DataType) -> Self {
+ Self {
+ min_max: vec![],
+ data_type,
+ total_data_bytes: 0,
+ }
+ }
+
+ /// Set the specified group to the given value, updating memory usage
appropriately
+ fn set_value(&mut self, group_index: usize, new_val: &[u8]) {
+ match self.min_max[group_index].as_mut() {
+ None => {
+ self.min_max[group_index] = Some(new_val.to_vec());
+ self.total_data_bytes += new_val.len();
+ }
+ Some(existing_val) => {
+ // Copy data over to avoid re-allocating
+ self.total_data_bytes -= existing_val.len();
+ self.total_data_bytes += new_val.len();
+ existing_val.clear();
+ existing_val.extend_from_slice(new_val);
+ }
+ }
+ }
+
+ /// Updates the min/max values for the given string values
+ ///
+ /// `cmp` is the comparison function to use, called like `cmp(new_val,
existing_val)`
+ /// returns true if the `new_val` should replace `existing_val`
+ fn update_batch<'a, F, I>(
+ &mut self,
+ iter: I,
+ group_indices: &[usize],
+ total_num_groups: usize,
+ mut cmp: F,
+ ) -> Result<()>
+ where
+ F: FnMut(&[u8], &[u8]) -> bool + Send + Sync,
+ I: IntoIterator<Item = Option<&'a [u8]>>,
+ {
+ self.min_max.resize(total_num_groups, None);
+ // Minimize value copies by calculating the new min/maxes for each
group
+ // in this batch (either the existing min/max or the new input value)
+        // and updating the owned values in `self.min_max` at most once
+ let mut locations = vec![MinMaxLocation::ExistingMinMax;
total_num_groups];
+
+ // Figure out the new min value for each group
+ for (new_val, group_index) in
iter.into_iter().zip(group_indices.iter()) {
+ let group_index = *group_index;
+ let Some(new_val) = new_val else {
+ continue; // skip nulls
+ };
+
+ let existing_val = match locations[group_index] {
+ // previous input value was the min/max, so compare it
+ MinMaxLocation::Input(existing_val) => existing_val,
+ MinMaxLocation::ExistingMinMax => {
+                    let Some(existing_val) = self.min_max[group_index].as_ref()
else {
+                        // no existing min/max, so this is the new min/max
+                        locations[group_index] =
MinMaxLocation::Input(new_val);
+                        continue;
+                    };
+                    existing_val.as_ref()
+ }
+ };
+
+ // Compare the new value to the existing value, replacing if
necessary
+ if cmp(new_val, existing_val) {
+ locations[group_index] = MinMaxLocation::Input(new_val);
+ }
+ }
+
+ // Update self.min_max with any new min/max values we found in the
input
+ for (group_index, location) in locations.iter().enumerate() {
+ match location {
+ MinMaxLocation::ExistingMinMax => {}
+ MinMaxLocation::Input(new_val) => self.set_value(group_index,
new_val),
+ }
+ }
+ Ok(())
+ }
+
+ /// Emits the specified min_max values
+ ///
+ /// Returns (data_capacity, min_maxes), updating the current value of
total_data_bytes
+ ///
+ /// - `data_capacity`: the total length of all strings and their contents,
+ /// - `min_maxes`: the actual min/max values for each group
+ fn emit_to(&mut self, emit_to: EmitTo) -> (usize, Vec<Option<Vec<u8>>>) {
+ match emit_to {
+ EmitTo::All => {
+ (
+ std::mem::take(&mut self.total_data_bytes), // reset total
bytes and min_max
+ std::mem::take(&mut self.min_max),
+ )
+ }
+ EmitTo::First(n) => {
+ let first_min_maxes: Vec<_> =
self.min_max.drain(..n).collect();
+ let first_data_capacity: usize = first_min_maxes
+ .iter()
+ .map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0))
+ .sum();
+ self.total_data_bytes -= first_data_capacity;
+ (first_data_capacity, first_min_maxes)
+ }
+ }
+ }
+
+ fn size(&self) -> usize {
+ self.total_data_bytes
+ + self.min_max.len() * std::mem::size_of::<Option<Vec<u8>>>()
+ }
+}
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index ce382a9bf8..f03c3700ab 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -3818,6 +3818,180 @@ DROP TABLE min_bool;
# Min_Max End #
#################
+
+
+#################
+# min_max on strings/binary with null values and groups
+#################
+
+statement ok
+CREATE TABLE strings (value TEXT, id int);
+
+statement ok
+INSERT INTO strings VALUES
+ ('c', 1),
+ ('d', 1),
+ ('a', 3),
+ ('c', 1),
+ ('b', 1),
+ (NULL, 1),
+ (NULL, 4),
+ ('d', 1),
+ ('z', 2),
+ ('c', 1),
+ ('a', 2);
+
+############ Utf8 ############
+
+query IT
+SELECT id, MIN(value) FROM strings GROUP BY id ORDER BY id;
+----
+1 b
+2 a
+3 a
+4 NULL
+
+query IT
+SELECT id, MAX(value) FROM strings GROUP BY id ORDER BY id;
+----
+1 d
+2 z
+3 a
+4 NULL
+
+############ LargeUtf8 ############
+
+statement ok
+CREATE VIEW large_strings AS SELECT id, arrow_cast(value, 'LargeUtf8') as
value FROM strings;
+
+
+query IT
+SELECT id, MIN(value) FROM large_strings GROUP BY id ORDER BY id;
+----
+1 b
+2 a
+3 a
+4 NULL
+
+query IT
+SELECT id, MAX(value) FROM large_strings GROUP BY id ORDER BY id;
+----
+1 d
+2 z
+3 a
+4 NULL
+
+statement ok
+DROP VIEW large_strings
+
+############ Utf8View ############
+
+statement ok
+CREATE VIEW string_views AS SELECT id, arrow_cast(value, 'Utf8View') as value
FROM strings;
+
+
+query IT
+SELECT id, MIN(value) FROM string_views GROUP BY id ORDER BY id;
+----
+1 b
+2 a
+3 a
+4 NULL
+
+query IT
+SELECT id, MAX(value) FROM string_views GROUP BY id ORDER BY id;
+----
+1 d
+2 z
+3 a
+4 NULL
+
+statement ok
+DROP VIEW string_views
+
+############ Binary ############
+
+statement ok
+CREATE VIEW binary AS SELECT id, arrow_cast(value, 'Binary') as value FROM
strings;
+
+
+query I?
+SELECT id, MIN(value) FROM binary GROUP BY id ORDER BY id;
+----
+1 62
+2 61
+3 61
+4 NULL
+
+query I?
+SELECT id, MAX(value) FROM binary GROUP BY id ORDER BY id;
+----
+1 64
+2 7a
+3 61
+4 NULL
+
+statement ok
+DROP VIEW binary
+
+############ LargeBinary ############
+
+statement ok
+CREATE VIEW large_binary AS SELECT id, arrow_cast(value, 'LargeBinary') as
value FROM strings;
+
+
+query I?
+SELECT id, MIN(value) FROM large_binary GROUP BY id ORDER BY id;
+----
+1 62
+2 61
+3 61
+4 NULL
+
+query I?
+SELECT id, MAX(value) FROM large_binary GROUP BY id ORDER BY id;
+----
+1 64
+2 7a
+3 61
+4 NULL
+
+statement ok
+DROP VIEW large_binary
+
+############ BinaryView ############
+
+statement ok
+CREATE VIEW binary_views AS SELECT id, arrow_cast(value, 'BinaryView') as
value FROM strings;
+
+
+query I?
+SELECT id, MIN(value) FROM binary_views GROUP BY id ORDER BY id;
+----
+1 62
+2 61
+3 61
+4 NULL
+
+query I?
+SELECT id, MAX(value) FROM binary_views GROUP BY id ORDER BY id;
+----
+1 64
+2 7a
+3 61
+4 NULL
+
+statement ok
+DROP VIEW binary_views
+
+statement ok
+DROP TABLE strings;
+
+#################
+# End min_max on strings/binary with null values and groups
+#################
+
+
statement ok
create table bool_aggregate_functions (
c1 boolean not null,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]