This is an automated email from the ASF dual-hosted git repository.
github-merge-queue[bot] pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 3c5361881a fix: Incorrect behavior for `FILTER` on NULLs (#22068)
3c5361881a is described below
commit 3c5361881a52f0fc673983cc0eee213723e12890
Author: Neil Conway <[email protected]>
AuthorDate: Tue May 12 00:52:18 2026 -0400
fix: Incorrect behavior for `FILTER` on NULLs (#22068)
## Which issue does this PR close?
- Closes #22067.
## Rationale for this change
In the grouping code, `accumulate_multiple` and `accumulate_indices`
take a `BooleanArray` parameter, which has the result of the aggregate's
`FILTER` clause (if any). Both functions only consider the value bits of
the array, not the NULL bitmap, which means a `NULL` filter result is
treated as whatever its underlying value bit happens to be (which may be
true), rather than as false.
## What changes are included in this PR?
* Fix NULL handling in `accumulate_multiple` and `accumulate_indices`
* Refactor `accumulate_multiple` to be more readable and make use of
`NullBuffer::union_many`
* Introduce a new helper, `filter_to_validity`
* Optimize `filter_to_nulls` to use `filter_to_validity` and avoid
constructing an unnecessary intermediate `NullBuffer`
* Add unit tests for NULL handling in `accumulate_multiple` and
`accumulate_indices`
* Add SLT tests with SQL repros for both code paths
## Are these changes tested?
Yes, with new tests added.
## Are there any user-facing changes?
This changes query behavior for the affected (buggy) queries.
This PR also changes the signature of `filter_to_nulls`, which is
technically a public API: it now returns `NullBuffer` instead of
`Option<NullBuffer>`.
---------
Co-authored-by: Matt Butrovich <[email protected]>
Co-authored-by: Daniël Heres <[email protected]>
---
.../src/aggregate/groups_accumulator/accumulate.rs | 114 ++++++++++++++++-----
.../src/aggregate/groups_accumulator/nulls.rs | 27 +++--
datafusion/functions-aggregate/src/array_agg.rs | 2 +-
datafusion/sqllogictest/test_files/aggregate.slt | 20 ++++
4 files changed, 125 insertions(+), 38 deletions(-)
diff --git
a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
index 25f52df611..09e1df4eae 100644
---
a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
+++
b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
@@ -23,6 +23,7 @@ use arrow::array::{Array, BooleanArray, BooleanBufferBuilder,
PrimitiveArray};
use arrow::buffer::NullBuffer;
use arrow::datatypes::ArrowPrimitiveType;
+use crate::aggregate::groups_accumulator::nulls::filter_to_validity;
use datafusion_expr_common::groups_accumulator::EmitTo;
/// If the input has nulls, then the accumulator must potentially
@@ -471,7 +472,7 @@ pub fn accumulate<T, F>(
///
/// This method assumes that for any input record index, if any of the value
column
/// is null, or it's filtered out by `opt_filter`, then the record would be
ignored.
-/// (won't be accumulated by `value_fn`)
+/// (Won't be accumulated by `value_fn`)
///
/// # Arguments
///
@@ -491,35 +492,28 @@ pub fn accumulate_multiple<T, F>(
T: ArrowPrimitiveType + Send,
F: FnMut(usize, usize, &[&PrimitiveArray<T>]) + Send,
{
- // Calculate `valid_indices` to accumulate, non-valid indices are ignored.
- // `valid_indices` is a bit mask corresponding to the `group_indices`. An
index
- // is considered valid if:
- // 1. All columns are non-null at this index.
- // 2. Not filtered out by `opt_filter`
-
- // Take AND from all null buffers of `value_columns`.
- let combined_nulls = value_columns
- .iter()
- .map(|arr| arr.logical_nulls())
- .fold(None, |acc, nulls| {
- NullBuffer::union(acc.as_ref(), nulls.as_ref())
- });
-
- // Take AND from previous combined nulls and `opt_filter`.
- let valid_indices = match (combined_nulls, opt_filter) {
- (None, None) => None,
- (None, Some(filter)) => Some(filter.clone()),
- (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(),
None)),
- (Some(nulls), Some(filter)) => {
- let combined = nulls.inner() & filter.values();
- Some(BooleanArray::new(combined, None))
- }
- };
-
for col in value_columns.iter() {
debug_assert_eq!(col.len(), group_indices.len());
}
+ // Start with rows where all value columns are non-null.
+ let mut valid_indices =
+ NullBuffer::union_many(value_columns.iter().map(|arr| arr.nulls()))
+ .map(NullBuffer::into_inner);
+
+ // Restrict to rows where the optional filter is Some(true). Keep the
filter
+ // as a raw BooleanBuffer to avoid computing a NullBuffer null_count just
to
+ // test row validity below.
+ if let Some(filter) = opt_filter {
+ debug_assert_eq!(filter.len(), group_indices.len());
+ let filter_validity = filter_to_validity(filter);
+ if let Some(valid_indices) = valid_indices.as_mut() {
+ *valid_indices &= &filter_validity;
+ } else {
+ valid_indices = Some(filter_validity);
+ }
+ }
+
match valid_indices {
None => {
for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
@@ -562,7 +556,8 @@ pub fn accumulate_indices<F>(
(None, Some(filter)) => {
debug_assert_eq!(filter.len(), group_indices.len());
let group_indices_chunks = group_indices.chunks_exact(64);
- let bit_chunks = filter.values().bit_chunks();
+ let filter_validity = filter_to_validity(filter);
+ let bit_chunks = filter_validity.bit_chunks();
let group_indices_remainder = group_indices_chunks.remainder();
@@ -636,7 +631,8 @@ pub fn accumulate_indices<F>(
let group_indices_chunks = group_indices.chunks_exact(64);
let valid_bit_chunks = valids.inner().bit_chunks();
- let filter_bit_chunks = filter.values().bit_chunks();
+ let filter_validity = filter_to_validity(filter);
+ let filter_bit_chunks = filter_validity.bit_chunks();
let group_indices_remainder = group_indices_chunks.remainder();
@@ -1188,6 +1184,68 @@ mod test {
assert_eq!(accumulated, expected);
}
+ #[test]
+ fn test_accumulate_indices_with_null_filter() {
+ let group_indices = vec![0, 1, 0, 1];
+ let filter = BooleanArray::new(
+ BooleanBuffer::from(vec![true, true, true, false]),
+ Some(NullBuffer::from(vec![true, false, true, true])),
+ );
+
+ let mut accumulated = vec![];
+ accumulate_indices(&group_indices, None, Some(&filter), |group_idx| {
+ accumulated.push(group_idx);
+ });
+
+ // A NULL filter value should be treated the same as false, even if the
+ // underlying BooleanBuffer value is true.
+ let expected = vec![0, 0];
+ assert_eq!(accumulated, expected);
+
+ let value_validity = NullBuffer::from(vec![true, true, false, true]);
+ let mut accumulated = vec![];
+ accumulate_indices(
+ &group_indices,
+ Some(&value_validity),
+ Some(&filter),
+ |group_idx| {
+ accumulated.push(group_idx);
+ },
+ );
+
+ let expected = vec![0];
+ assert_eq!(accumulated, expected);
+ }
+
+ #[test]
+ fn test_accumulate_multiple_with_null_filter() {
+ let group_indices = vec![0, 1, 0, 1];
+ let values1 = Int32Array::from(vec![1, 2, 3, 4]);
+ let values2 = Int32Array::from(vec![10, 20, 30, 40]);
+ let value_columns = [values1, values2];
+
+ let filter = BooleanArray::new(
+ BooleanBuffer::from(vec![true, true, true, false]),
+ Some(NullBuffer::from(vec![true, false, true, true])),
+ );
+
+ let mut accumulated = vec![];
+ accumulate_multiple(
+ &group_indices,
+ &value_columns.iter().collect::<Vec<_>>(),
+ Some(&filter),
+ |group_idx, batch_idx, columns| {
+ let values = columns.iter().map(|col|
col.value(batch_idx)).collect();
+ accumulated.push((group_idx, values));
+ },
+ );
+
+ // A NULL filter value should be treated the same as false, even if the
+ // underlying BooleanBuffer value is true.
+ let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
+ assert_eq!(accumulated, expected);
+ }
+
#[test]
fn test_accumulate_multiple_with_nulls_and_filter() {
let group_indices = vec![0, 1, 0, 1];
diff --git
a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs
b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs
index 5b56b77e11..d524afe43a 100644
---
a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs
+++
b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs
@@ -22,7 +22,7 @@ use arrow::array::{
BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray,
StringArray,
StringViewArray, StructArray,
};
-use arrow::buffer::NullBuffer;
+use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::datatypes::DataType;
use datafusion_common::{Result, not_impl_err};
use std::sync::Arc;
@@ -39,15 +39,24 @@ pub fn set_nulls<T: ArrowNumericType + Send>(
PrimitiveArray::<T>::new(values, nulls).with_data_type(dt)
}
-/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer.
+/// Converts an aggregate filter expression to a validity bitmap.
+///
+/// The output is `true` for rows where the filter is `Some(true)`, and `false`
+/// for rows where the filter is `Some(false)` or `None`.
+pub(crate) fn filter_to_validity(filter: &BooleanArray) -> BooleanBuffer {
+ let Some(filter_nulls) = filter.nulls() else {
+ return filter.values().clone();
+ };
+ filter.values() & filter_nulls.inner()
+}
+
+/// Converts an aggregate filter expression to a `NullBuffer`.
///
/// The `NullBuffer` is
-/// * `true` (representing valid) for values that were `true` in filter
-/// * `false` (representing null) for values that were `false` or `null` in
filter
-pub fn filter_to_nulls(filter: &BooleanArray) -> Option<NullBuffer> {
- let (filter_bools, filter_nulls) = filter.clone().into_parts();
- let filter_bools = NullBuffer::from(filter_bools);
- NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref())
+/// * `true` (representing valid) for filter values that were `Some(true)`
+/// * `false` (representing null) for filter values that were `Some(false)` or
`None`
+pub fn filter_to_nulls(filter: &BooleanArray) -> NullBuffer {
+ NullBuffer::new(filter_to_validity(filter))
}
/// Compute an output validity mask for an array that has been filtered
@@ -97,7 +106,7 @@ pub fn filtered_null_mask(
opt_filter: Option<&BooleanArray>,
input: &dyn Array,
) -> Option<NullBuffer> {
- let opt_filter = opt_filter.and_then(filter_to_nulls);
+ let opt_filter = opt_filter.map(filter_to_nulls);
NullBuffer::union(opt_filter.as_ref(), input.nulls())
}
diff --git a/datafusion/functions-aggregate/src/array_agg.rs
b/datafusion/functions-aggregate/src/array_agg.rs
index 4365c436a5..24edaaff1f 100644
--- a/datafusion/functions-aggregate/src/array_agg.rs
+++ b/datafusion/functions-aggregate/src/array_agg.rs
@@ -776,7 +776,7 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator {
let offsets = OffsetBuffer::from_repeated_length(1, input.len());
// Filtered rows become null list entries, which merge_batch will skip.
- let filter_nulls = opt_filter.and_then(filter_to_nulls);
+ let filter_nulls = opt_filter.map(filter_to_nulls);
// With ignore_nulls, null values also become null list entries.
Without
// ignore_nulls, null values stay as [NULL] so merge_batch retains
them.
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index 70acff3cb7..b8009dfd57 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -693,6 +693,18 @@ from data
----
1
+# correlation_with_group_by_and_nullable_filter
+query IR rowsort
+SELECT g, corr(x, y) FILTER (WHERE b < 1) AS r
+FROM (VALUES
+ (0, 1.0, 1.0, CAST(NULL AS INT)),
+ (0, 2.0, 2.0, CAST(NULL AS INT)),
+ (0, 3.0, 4.0, 2)
+) AS t(g, x, y, b)
+GROUP BY g
+----
+0 NULL
+
# group correlation_query_with_nans_f32
query IR
select id, corr(f, b)
@@ -6177,6 +6189,14 @@ FROM test_table
----
2
+# count_with_group_by_and_nullable_filter
+query II rowsort
+SELECT g, COUNT(a) FILTER (WHERE b < 1) AS count_a
+FROM (VALUES (0, 1, CAST(NULL AS INT)), (0, 2, 2)) AS t(g, a, b)
+GROUP BY g
+----
+0 0
+
# query_with_and_without_filter
query III rowsort
SELECT
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]